Skip to content

Commit 3b0bc86

Browse files
authored
Merge pull request #763 from tisnik/lcore-740-minor-fixes
LCORE-740: minor fixes in endpoint unit tests
2 parents 4f39c33 + 1352820 commit 3b0bc86

File tree

4 files changed

+125
-90
lines changed

4 files changed

+125
-90
lines changed

tests/unit/app/endpoints/test_metrics.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@ async def test_metrics_endpoint(mocker: MockerFixture) -> None:
1515
mock_authorization_resolvers(mocker)
1616

1717
mock_setup_metrics = mocker.patch(
18-
"app.endpoints.metrics.setup_model_metrics", return_value=None
18+
"app.endpoints.metrics.setup_model_metrics",
19+
new=mocker.AsyncMock(return_value=None),
1920
)
2021
request = Request(
2122
scope={
2223
"type": "http",
2324
}
2425
)
25-
auth: AuthTuple = ("test_user", "token", {})
26+
27+
# Authorization tuple required by URL endpoint handler
28+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
29+
2630
response = await metrics_endpoint_handler(auth=auth, request=request)
2731
assert response is not None
2832
assert response.status_code == 200

tests/unit/app/endpoints/test_models.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ async def test_models_endpoint_handler_configuration_not_loaded(
3434
"headers": [(b"authorization", b"Bearer invalid-token")],
3535
}
3636
)
37-
auth: AuthTuple = ("user_id", "user_name", "token")
37+
38+
# Authorization tuple required by URL endpoint handler
39+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
3840

3941
with pytest.raises(HTTPException) as e:
4042
await models_endpoint_handler(request=request, auth=auth)
@@ -87,7 +89,10 @@ async def test_models_endpoint_handler_improper_llama_stack_configuration(
8789
"headers": [(b"authorization", b"Bearer invalid-token")],
8890
}
8991
)
90-
auth: AuthTuple = ("test_user", "token", {})
92+
93+
# Authorization tuple required by URL endpoint handler
94+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
95+
9196
with pytest.raises(HTTPException) as e:
9297
await models_endpoint_handler(request=request, auth=auth)
9398
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@@ -133,7 +138,9 @@ async def test_models_endpoint_handler_configuration_loaded(
133138
"headers": [(b"authorization", b"Bearer invalid-token")],
134139
}
135140
)
136-
auth: AuthTuple = ("test_user", "token", {})
141+
142+
# Authorization tuple required by URL endpoint handler
143+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
137144

138145
with pytest.raises(HTTPException) as e:
139146
await models_endpoint_handler(request=request, auth=auth)
@@ -177,7 +184,9 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(
177184
# Mock the LlamaStack client
178185
mock_client = mocker.AsyncMock()
179186
mock_client.models.list.return_value = []
180-
mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client")
187+
mock_lsc = mocker.patch(
188+
"app.endpoints.models.AsyncLlamaStackClientHolder.get_client"
189+
)
181190
mock_lsc.return_value = mock_client
182191
mock_config = mocker.Mock()
183192
mocker.patch("app.endpoints.models.configuration", mock_config)
@@ -188,7 +197,10 @@ async def test_models_endpoint_handler_unable_to_retrieve_models_list(
188197
"headers": [(b"authorization", b"Bearer invalid-token")],
189198
}
190199
)
191-
auth: AuthTuple = ("test_user", "token", {})
200+
201+
# Authorization tuple required by URL endpoint handler
202+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
203+
192204
response = await models_endpoint_handler(request=request, auth=auth)
193205
assert response is not None
194206

@@ -242,7 +254,9 @@ async def test_models_endpoint_llama_stack_connection_error(
242254
"headers": [(b"authorization", b"Bearer invalid-token")],
243255
}
244256
)
245-
auth: AuthTuple = ("test_user", "token", {})
257+
258+
# Authorization tuple required by URL endpoint handler
259+
auth: AuthTuple = ("test_user_id", "test_user", True, "test_token")
246260

247261
with pytest.raises(HTTPException) as e:
248262
await models_endpoint_handler(request=request, auth=auth)

tests/unit/app/endpoints/test_query_v2.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# pylint: disable=redefined-outer-name, import-error
22
"""Unit tests for the /query (v2) REST API endpoint using Responses API."""
33

4+
from typing import Any
45
import pytest
6+
from pytest_mock import MockerFixture
57
from fastapi import HTTPException, status, Request
68

79
from llama_stack_client import APIConnectionError
@@ -24,7 +26,7 @@ def dummy_request() -> Request:
2426
return req
2527

2628

27-
def test_get_rag_tools():
29+
def test_get_rag_tools() -> None:
2830
"""Test get_rag_tools returns None for empty list and correct tool format for vector stores."""
2931
assert get_rag_tools([]) is None
3032

@@ -35,7 +37,7 @@ def test_get_rag_tools():
3537
assert tools[0]["max_num_results"] == 10
3638

3739

38-
def test_get_mcp_tools_with_and_without_token():
40+
def test_get_mcp_tools_with_and_without_token() -> None:
3941
"""Test get_mcp_tools generates correct tool definitions with and without auth tokens."""
4042
servers = [
4143
ModelContextProtocolServer(name="fs", url="http://localhost:3000"),
@@ -58,7 +60,7 @@ def test_get_mcp_tools_with_and_without_token():
5860

5961

6062
@pytest.mark.asyncio
61-
async def test_retrieve_response_no_tools_bypasses_tools(mocker):
63+
async def test_retrieve_response_no_tools_bypasses_tools(mocker: MockerFixture) -> None:
6264
"""Test that no_tools=True bypasses tool configuration and passes None to responses API."""
6365
mock_client = mocker.Mock()
6466
# responses.create returns a synthetic OpenAI-like response
@@ -94,7 +96,9 @@ async def test_retrieve_response_no_tools_bypasses_tools(mocker):
9496

9597

9698
@pytest.mark.asyncio
97-
async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
99+
async def test_retrieve_response_builds_rag_and_mcp_tools(
100+
mocker: MockerFixture,
101+
) -> None:
98102
"""Test that retrieve_response correctly builds RAG and MCP tools from configuration."""
99103
mock_client = mocker.Mock()
100104
response_obj = mocker.Mock()
@@ -137,7 +141,9 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(mocker):
137141

138142

139143
@pytest.mark.asyncio
140-
async def test_retrieve_response_parses_output_and_tool_calls(mocker):
144+
async def test_retrieve_response_parses_output_and_tool_calls(
145+
mocker: MockerFixture,
146+
) -> None:
141147
"""Test that retrieve_response correctly parses output content and tool calls from response."""
142148
mock_client = mocker.Mock()
143149

@@ -190,7 +196,7 @@ async def test_retrieve_response_parses_output_and_tool_calls(mocker):
190196

191197

192198
@pytest.mark.asyncio
193-
async def test_retrieve_response_with_usage_info(mocker):
199+
async def test_retrieve_response_with_usage_info(mocker: MockerFixture) -> None:
194200
"""Test that token usage is extracted when provided by the API as an object."""
195201
mock_client = mocker.Mock()
196202

@@ -231,7 +237,7 @@ async def test_retrieve_response_with_usage_info(mocker):
231237

232238

233239
@pytest.mark.asyncio
234-
async def test_retrieve_response_with_usage_dict(mocker):
240+
async def test_retrieve_response_with_usage_dict(mocker: MockerFixture) -> None:
235241
"""Test that token usage is extracted when provided by the API as a dict."""
236242
mock_client = mocker.Mock()
237243

@@ -268,7 +274,7 @@ async def test_retrieve_response_with_usage_dict(mocker):
268274

269275

270276
@pytest.mark.asyncio
271-
async def test_retrieve_response_with_empty_usage_dict(mocker):
277+
async def test_retrieve_response_with_empty_usage_dict(mocker: MockerFixture) -> None:
272278
"""Test that empty usage dict is handled gracefully."""
273279
mock_client = mocker.Mock()
274280

@@ -305,7 +311,7 @@ async def test_retrieve_response_with_empty_usage_dict(mocker):
305311

306312

307313
@pytest.mark.asyncio
308-
async def test_retrieve_response_validates_attachments(mocker):
314+
async def test_retrieve_response_validates_attachments(mocker: MockerFixture) -> None:
309315
"""Test that retrieve_response validates attachments and includes them in the input string."""
310316
mock_client = mocker.Mock()
311317
response_obj = mocker.Mock()
@@ -345,7 +351,9 @@ async def test_retrieve_response_validates_attachments(mocker):
345351

346352

347353
@pytest.mark.asyncio
348-
async def test_query_endpoint_handler_v2_success(mocker, dummy_request):
354+
async def test_query_endpoint_handler_v2_success(
355+
mocker: MockerFixture, dummy_request: Request
356+
) -> None:
349357
"""Test successful query endpoint handler execution with proper response structure."""
350358
# Mock configuration to avoid configuration not loaded errors
351359
mock_config = mocker.Mock()
@@ -396,15 +404,18 @@ async def test_query_endpoint_handler_v2_success(mocker, dummy_request):
396404

397405

398406
@pytest.mark.asyncio
399-
async def test_query_endpoint_handler_v2_api_connection_error(mocker, dummy_request):
407+
async def test_query_endpoint_handler_v2_api_connection_error(
408+
mocker: MockerFixture, dummy_request: Request
409+
) -> None:
400410
"""Test that query endpoint handler properly handles and reports API connection errors."""
401411
# Mock configuration to avoid configuration not loaded errors
402412
mock_config = mocker.Mock()
403413
mock_config.llama_stack_configuration = mocker.Mock()
404414
mocker.patch("app.endpoints.query_v2.configuration", mock_config)
405415

406-
def _raise(*_args, **_kwargs):
407-
raise APIConnectionError(request=None)
416+
def _raise(*_args: Any, **_kwargs: Any) -> Exception:
417+
request = Request(scope={"type": "http"})
418+
raise APIConnectionError(request=request) # type: ignore
408419

409420
mocker.patch("client.AsyncLlamaStackClientHolder.get_client", side_effect=_raise)
410421

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 75 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -233,79 +233,85 @@ async def _test_streaming_query_endpoint_handler(
233233
# We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing
234234
# attribute and therefore makes checks to see whether it is missing fail.
235235
mock_streaming_response = mocker.AsyncMock()
236-
mock_streaming_response.__aiter__.return_value = [
237-
AgentTurnResponseStreamChunk(
238-
event=TurnResponseEvent(
239-
payload=AgentTurnResponseStepProgressPayload(
240-
event_type="step_progress",
241-
step_type="inference",
242-
delta=TextDelta(text="LLM ", type="text"),
243-
step_id="s1",
236+
mock_streaming_response.__aiter__.return_value = iter(
237+
[
238+
AgentTurnResponseStreamChunk(
239+
event=TurnResponseEvent(
240+
payload=AgentTurnResponseStepProgressPayload(
241+
event_type="step_progress",
242+
step_type="inference",
243+
delta=TextDelta(text="LLM ", type="text"),
244+
step_id="s1",
245+
)
244246
)
245-
)
246-
),
247-
AgentTurnResponseStreamChunk(
248-
event=TurnResponseEvent(
249-
payload=AgentTurnResponseStepProgressPayload(
250-
event_type="step_progress",
251-
step_type="inference",
252-
delta=TextDelta(text="answer", type="text"),
253-
step_id="s2",
247+
),
248+
AgentTurnResponseStreamChunk(
249+
event=TurnResponseEvent(
250+
payload=AgentTurnResponseStepProgressPayload(
251+
event_type="step_progress",
252+
step_type="inference",
253+
delta=TextDelta(text="answer", type="text"),
254+
step_id="s2",
255+
)
254256
)
255-
)
256-
),
257-
AgentTurnResponseStreamChunk(
258-
event=TurnResponseEvent(
259-
payload=AgentTurnResponseStepCompletePayload(
260-
event_type="step_complete",
261-
step_id="s1",
262-
step_type="tool_execution",
263-
step_details=ToolExecutionStep(
264-
turn_id="t1",
265-
step_id="s3",
257+
),
258+
AgentTurnResponseStreamChunk(
259+
event=TurnResponseEvent(
260+
payload=AgentTurnResponseStepCompletePayload(
261+
event_type="step_complete",
262+
step_id="s1",
266263
step_type="tool_execution",
267-
tool_responses=[
268-
ToolResponse(
269-
call_id="t1",
270-
tool_name="knowledge_search",
264+
step_details=ToolExecutionStep(
265+
turn_id="t1",
266+
step_id="s3",
267+
step_type="tool_execution",
268+
tool_responses=[
269+
ToolResponse(
270+
call_id="t1",
271+
tool_name="knowledge_search",
272+
content=[
273+
TextContentItem(text=s, type="text")
274+
for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS
275+
],
276+
)
277+
],
278+
tool_calls=[
279+
ToolCall(
280+
call_id="t1",
281+
tool_name="knowledge_search",
282+
arguments={},
283+
)
284+
],
285+
),
286+
)
287+
)
288+
),
289+
AgentTurnResponseStreamChunk(
290+
event=TurnResponseEvent(
291+
payload=AgentTurnResponseTurnCompletePayload(
292+
event_type="turn_complete",
293+
turn=Turn(
294+
turn_id="t1",
295+
input_messages=[],
296+
output_message=CompletionMessage(
297+
role="assistant",
271298
content=[
272-
TextContentItem(text=s, type="text")
273-
for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS
299+
TextContentItem(text="LLM answer", type="text")
274300
],
275-
)
276-
],
277-
tool_calls=[
278-
ToolCall(
279-
call_id="t1", tool_name="knowledge_search", arguments={}
280-
)
281-
],
282-
),
283-
)
284-
)
285-
),
286-
AgentTurnResponseStreamChunk(
287-
event=TurnResponseEvent(
288-
payload=AgentTurnResponseTurnCompletePayload(
289-
event_type="turn_complete",
290-
turn=Turn(
291-
turn_id="t1",
292-
input_messages=[],
293-
output_message=CompletionMessage(
294-
role="assistant",
295-
content=[TextContentItem(text="LLM answer", type="text")],
296-
stop_reason="end_of_turn",
297-
tool_calls=[],
301+
stop_reason="end_of_turn",
302+
tool_calls=[],
303+
),
304+
session_id="test_session_id",
305+
started_at=datetime.now(),
306+
steps=[],
307+
completed_at=datetime.now(),
308+
output_attachments=[],
298309
),
299-
session_id="test_session_id",
300-
started_at=datetime.now(),
301-
steps=[],
302-
completed_at=datetime.now(),
303-
output_attachments=[],
304-
),
310+
)
305311
)
306-
)
307-
),
308-
]
312+
),
313+
]
314+
)
309315

310316
mock_store_in_cache = mocker.patch(
311317
"app.endpoints.streaming_query.store_conversation_into_cache"
@@ -349,13 +355,13 @@ async def _test_streaming_query_endpoint_handler(
349355
assert isinstance(response, StreamingResponse)
350356

351357
# Collect the streaming response content
352-
streaming_content = []
358+
streaming_content: list[str] = []
353359
# response.body_iterator is an async generator, iterate over it directly
354360
async for chunk in response.body_iterator:
355-
streaming_content.append(chunk)
361+
streaming_content.append(str(chunk))
356362

357363
# Convert to string for assertions
358-
full_content = "".join(streaming_content) # type: ignore
364+
full_content = "".join(streaming_content)
359365

360366
# Assert the streaming content contains expected SSE format
361367
assert "data: " in full_content

0 commit comments

Comments
 (0)