Skip to content

Commit 7d24dd4

Browse files
committed
allow disabling query model and provider
1 parent 18fdf3c commit 7d24dd4

File tree

7 files changed

+181
-0
lines changed

7 files changed

+181
-0
lines changed

src/app/endpoints/query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
get_agent,
3636
get_system_prompt,
3737
validate_conversation_ownership,
38+
validate_model_provider_override,
3839
)
3940
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
4041
from utils.transcripts import store_transcript
@@ -174,6 +175,9 @@ async def query_endpoint_handler(
174175
"""
175176
check_configuration_loaded(configuration)
176177

178+
# Enforce configuration: optionally disallow overriding model/provider in requests
179+
validate_model_provider_override(query_request, configuration)
180+
177181
# log Llama Stack configuration, but without sensitive information
178182
llama_stack_config = configuration.llama_stack_configuration.model_copy()
179183
llama_stack_config.api_key = "********"

src/app/endpoints/streaming_query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3434
from utils.transcripts import store_transcript
3535
from utils.types import TurnSummary
36+
from utils.endpoints import validate_model_provider_override
3637

3738
from app.endpoints.query import (
3839
get_rag_toolgroups,
@@ -548,6 +549,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
548549

549550
check_configuration_loaded(configuration)
550551

552+
# Enforce configuration: optionally disallow overriding model/provider in requests
553+
validate_model_provider_override(query_request, configuration)
554+
551555
# log Llama Stack configuration, but without sensitive information
552556
llama_stack_config = configuration.llama_stack_configuration.model_copy()
553557
llama_stack_config.api_key = "********"

src/models/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ class Customization(ConfigurationBase):
388388
"""Service customization."""
389389

390390
disable_query_system_prompt: bool = False
391+
disable_query_model_provider_override: bool = False
391392
system_prompt_path: Optional[FilePath] = None
392393
system_prompt: Optional[str] = None
393394

src/utils/endpoints.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,28 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
8484
return constants.DEFAULT_SYSTEM_PROMPT
8585

8686

87+
def validate_model_provider_override(query_request: QueryRequest, config: AppConfig) -> None:
88+
"""Validate whether model/provider overrides are allowed.
89+
90+
Raises HTTP 422 if overrides are disabled and the request includes model or provider.
91+
"""
92+
disabled = (
93+
config.customization is not None
94+
and getattr(config.customization, "disable_query_model_provider_override", False)
95+
)
96+
if disabled and (query_request.model is not None or query_request.provider is not None):
97+
raise HTTPException(
98+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
99+
detail={
100+
"response": (
101+
"This instance does not support overriding model/provider in the query request "
102+
"(disable_query_model_provider_override is set). Please remove the model and provider "
103+
"fields from your request."
104+
)
105+
},
106+
)
107+
108+
87109
# # pylint: disable=R0913,R0917
88110
async def get_agent(
89111
client: AsyncLlamaStackClient,

tests/unit/app/endpoints/test_query.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,3 +1507,51 @@ def test_evaluate_model_hints(
15071507

15081508
assert provider_id == expected_provider
15091509
assert model_id == expected_model
1510+
1511+
1512+
@pytest.mark.asyncio
1513+
async def test_query_endpoint_rejects_model_provider_override_when_disabled(
1514+
mocker, dummy_request
1515+
):
1516+
"""Assert 422 and message when disable_query_model_provider_override is set and request includes model/provider."""
1517+
# Prepare configuration with the override disabled
1518+
config_dict = {
1519+
"name": "test",
1520+
"service": {
1521+
"host": "localhost",
1522+
"port": 8080,
1523+
"auth_enabled": False,
1524+
"workers": 1,
1525+
"color_log": True,
1526+
"access_log": True,
1527+
},
1528+
"llama_stack": {
1529+
"api_key": "test-key",
1530+
"url": "http://test.com:1234",
1531+
"use_as_library_client": False,
1532+
},
1533+
"user_data_collection": {"transcripts_enabled": False},
1534+
"mcp_servers": [],
1535+
"customization": {
1536+
"disable_query_model_provider_override": True,
1537+
},
1538+
}
1539+
cfg = AppConfig()
1540+
cfg.init_from_dict(config_dict)
1541+
1542+
# Patch endpoint configuration
1543+
mocker.patch("app.endpoints.query.configuration", cfg)
1544+
1545+
# Build a request that tries to override model/provider
1546+
query_request = QueryRequest(query="What?", model="m", provider="p")
1547+
1548+
with pytest.raises(HTTPException) as exc_info:
1549+
await query_endpoint_handler(
1550+
request=dummy_request, query_request=query_request, auth=MOCK_AUTH
1551+
)
1552+
1553+
assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
1554+
assert (
1555+
exc_info.value.detail["response"]
1556+
== "This instance does not support overriding model/provider in the query request (disable_query_model_provider_override is set). Please remove the model and provider fields from your request."
1557+
)

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,3 +1515,53 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
15151515
stream=True,
15161516
toolgroups=expected_toolgroups,
15171517
)
1518+
1519+
1520+
@pytest.mark.asyncio
1521+
async def test_streaming_query_endpoint_rejects_model_provider_override_when_disabled(mocker):
1522+
"""Assert 422 and message when disable_query_model_provider_override is set and request includes model/provider."""
1523+
# Prepare configuration with the override disabled
1524+
config_dict = {
1525+
"name": "test",
1526+
"service": {
1527+
"host": "localhost",
1528+
"port": 8080,
1529+
"auth_enabled": False,
1530+
"workers": 1,
1531+
"color_log": True,
1532+
"access_log": True,
1533+
},
1534+
"llama_stack": {
1535+
"api_key": "test-key",
1536+
"url": "http://test.com:1234",
1537+
"use_as_library_client": False,
1538+
},
1539+
"user_data_collection": {"transcripts_enabled": False},
1540+
"mcp_servers": [],
1541+
"customization": {
1542+
"disable_query_model_provider_override": True,
1543+
},
1544+
}
1545+
cfg = AppConfig()
1546+
cfg.init_from_dict(config_dict)
1547+
1548+
# Patch endpoint configuration
1549+
mocker.patch("app.endpoints.streaming_query.configuration", cfg)
1550+
1551+
# Build a query request that tries to override model/provider
1552+
query_request = QueryRequest(query="What?", model="m", provider="p")
1553+
1554+
request = Request(
1555+
scope={
1556+
"type": "http",
1557+
}
1558+
)
1559+
1560+
with pytest.raises(HTTPException) as exc_info:
1561+
await streaming_query_endpoint_handler(request, query_request, auth=MOCK_AUTH)
1562+
1563+
assert exc_info.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
1564+
assert (
1565+
exc_info.value.detail["response"]
1566+
== "This instance does not support overriding model/provider in the query request (disable_query_model_provider_override is set). Please remove the model and provider fields from your request."
1567+
)

tests/unit/utils/test_endpoints.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,55 @@ async def test_get_agent_no_tools_false_preserves_parser(
591591
tool_parser=mock_parser,
592592
enable_session_persistence=True,
593593
)
594+
595+
596+
@pytest.fixture(name="config_with_override_disabled")
597+
def config_with_override_disabled_fixture():
598+
"""Configuration where overriding model/provider is allowed (flag False)."""
599+
test_config = config_dict.copy()
600+
test_config["customization"] = {
601+
"disable_query_model_provider_override": False,
602+
}
603+
cfg = AppConfig()
604+
cfg.init_from_dict(test_config)
605+
return cfg
606+
607+
608+
@pytest.fixture(name="config_with_override_enabled")
609+
def config_with_override_enabled_fixture():
610+
"""Configuration where overriding model/provider is NOT allowed (flag True)."""
611+
test_config = config_dict.copy()
612+
test_config["customization"] = {
613+
"disable_query_model_provider_override": True,
614+
}
615+
cfg = AppConfig()
616+
cfg.init_from_dict(test_config)
617+
return cfg
618+
619+
620+
def test_validate_model_provider_override_allowed_when_flag_false(
621+
config_with_override_disabled,
622+
):
623+
"""Ensure no exception when overrides are allowed and request includes model/provider."""
624+
query_request = QueryRequest(query="q", model="m", provider="p")
625+
# Should not raise
626+
endpoints.validate_model_provider_override(query_request, config_with_override_disabled)
627+
628+
629+
def test_validate_model_provider_override_rejected_when_flag_true(
630+
config_with_override_enabled,
631+
):
632+
"""Ensure HTTP 422 when overrides are disabled and request includes model/provider."""
633+
query_request = QueryRequest(query="q", model="m", provider="p")
634+
with pytest.raises(HTTPException) as exc_info:
635+
endpoints.validate_model_provider_override(query_request, config_with_override_enabled)
636+
assert exc_info.value.status_code == 422
637+
638+
639+
def test_validate_model_provider_override_no_override_with_flag_true(
640+
config_with_override_enabled,
641+
):
642+
"""No exception when overrides are disabled but request does not include model/provider."""
643+
query_request = QueryRequest(query="q")
644+
# Should not raise
645+
endpoints.validate_model_provider_override(query_request, config_with_override_enabled)

0 commit comments

Comments
 (0)