Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ customization:
disable_query_system_prompt: true
```

### Control model/provider overrides via authorization

By default, clients may specify `model` and `provider` in `/v1/query` and `/v1/streaming_query` requests to override the configured model and provider. This override is permitted only for callers that have been granted the `MODEL_OVERRIDE` action via the authorization rules; requests that include `model` or `provider` without this permission are rejected with HTTP 403.

## Safety Shields

A single Llama Stack configuration file can include multiple safety shields, which are utilized in agent
Expand Down
3 changes: 2 additions & 1 deletion docs/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,8 @@
"get_models",
"get_metrics",
"get_config",
"info"
"info",
"model_override"
],
"title": "Action",
"description": "Available actions in the system."
Expand Down
4 changes: 4 additions & 0 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
get_agent,
get_system_prompt,
validate_conversation_ownership,
validate_model_provider_override,
)
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.transcripts import store_transcript
Expand Down Expand Up @@ -174,6 +175,9 @@ async def query_endpoint_handler(
"""
check_configuration_loaded(configuration)

# Enforce RBAC: optionally disallow overriding model/provider in requests
validate_model_provider_override(query_request, request.state.authorized_actions)

# log Llama Stack configuration, but without sensitive information
llama_stack_config = configuration.llama_stack_configuration.model_copy()
llama_stack_config.api_key = "********"
Expand Down
4 changes: 4 additions & 0 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.transcripts import store_transcript
from utils.types import TurnSummary
from utils.endpoints import validate_model_provider_override

from app.endpoints.query import (
get_rag_toolgroups,
Expand Down Expand Up @@ -548,6 +549,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals

check_configuration_loaded(configuration)

# Enforce RBAC: optionally disallow overriding model/provider in requests
validate_model_provider_override(query_request, request.state.authorized_actions)

# log Llama Stack configuration, but without sensitive information
llama_stack_config = configuration.llama_stack_configuration.model_copy()
llama_stack_config.api_key = "********"
Expand Down
23 changes: 15 additions & 8 deletions src/authorization/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import wraps, lru_cache
from typing import Any, Callable, Tuple
from fastapi import HTTPException, status
from starlette.requests import Request

from authorization.resolvers import (
AccessResolver,
Expand Down Expand Up @@ -64,7 +65,9 @@ def get_authorization_resolvers() -> Tuple[RolesResolver, AccessResolver]:
)


async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -> None:
async def _perform_authorization_check(
action: Action, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> None:
"""Perform authorization check - common logic for all decorators."""
role_resolver, access_resolver = get_authorization_resolvers()

Expand Down Expand Up @@ -93,12 +96,16 @@ async def _perform_authorization_check(action: Action, kwargs: dict[str, Any]) -

authorized_actions = access_resolver.get_actions(user_roles)

try:
request = kwargs["request"]
request.state.authorized_actions = authorized_actions
except KeyError:
# This endpoint doesn't seem to care about the authorized actions, so no need to set them
pass
req: Request | None = None
if "request" in kwargs and isinstance(kwargs["request"], Request):
req = kwargs["request"]
else:
for arg in args:
if isinstance(arg, Request):
req = arg
break
Comment on lines +100 to +106
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

if req is not None:
req.state.authorized_actions = authorized_actions


def authorize(action: Action) -> Callable:
Expand All @@ -107,7 +114,7 @@ def authorize(action: Action) -> Callable:
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
await _perform_authorization_check(action, kwargs)
await _perform_authorization_check(action, args, kwargs)
return await func(*args, **kwargs)

return wrapper
Expand Down
2 changes: 2 additions & 0 deletions src/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ class Action(str, Enum):
GET_CONFIG = "get_config"

INFO = "info"
# Allow overriding model/provider via request
MODEL_OVERRIDE = "model_override"


class AccessRule(ConfigurationBase):
Expand Down
24 changes: 24 additions & 0 deletions src/utils/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import constants
from models.requests import QueryRequest
from models.database.conversations import UserConversation
from models.config import Action
from app.database import get_session
from configuration import AppConfig
from utils.suid import get_suid
Expand Down Expand Up @@ -84,6 +85,29 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
return constants.DEFAULT_SYSTEM_PROMPT


def validate_model_provider_override(
    query_request: QueryRequest, authorized_actions: set[Action] | frozenset[Action]
) -> None:
    """Reject model/provider overrides from callers lacking MODEL_OVERRIDE.

    The request counts as an override when either ``model`` or ``provider``
    is set (i.e. is not ``None``). Requests that set neither field pass
    regardless of the caller's permissions.

    Args:
        query_request: Incoming query whose ``model`` and ``provider``
            fields are inspected.
        authorized_actions: Actions granted to the caller.

    Raises:
        HTTPException: With status 403 when the request sets ``model`` or
            ``provider`` and ``Action.MODEL_OVERRIDE`` is not among
            ``authorized_actions``.
    """
    wants_override = (
        query_request.model is not None or query_request.provider is not None
    )
    if not wants_override:
        # Nothing to override, so no permission is required.
        return

    if Action.MODEL_OVERRIDE in authorized_actions:
        return

    message = (
        "This instance does not permit overriding model/provider in the query request "
        "(missing permission: MODEL_OVERRIDE). Please remove the model and provider "
        "fields from your request."
    )
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail={"response": message},
    )


# # pylint: disable=R0913,R0917
async def get_agent(
client: AsyncLlamaStackClient,
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from models.config import Action, ModelContextProtocolServer
from models.database.conversations import UserConversation
from utils.types import ToolCallSummary, TurnSummary
from authorization.resolvers import NoopRolesResolver

MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token")

Expand Down Expand Up @@ -1507,3 +1508,58 @@ def test_evaluate_model_hints(

assert provider_id == expected_provider
assert model_id == expected_model


@pytest.mark.asyncio
async def test_query_endpoint_rejects_model_provider_override_without_permission(
mocker, dummy_request
):
"""Assert 403 and message when request includes model/provider without MODEL_OVERRIDE."""
# Patch endpoint configuration (no need to set customization)
cfg = AppConfig()
cfg.init_from_dict(
{
"name": "test",
"service": {
"host": "localhost",
"port": 8080,
"auth_enabled": False,
"workers": 1,
"color_log": True,
"access_log": True,
},
"llama_stack": {
"api_key": "test-key",
"url": "http://test.com:1234",
"use_as_library_client": False,
},
"user_data_collection": {"transcripts_enabled": False},
"mcp_servers": [],
}
)
mocker.patch("app.endpoints.query.configuration", cfg)

# Patch authorization to exclude MODEL_OVERRIDE from authorized actions
access_resolver = mocker.Mock()
access_resolver.check_access.return_value = True
access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE}
mocker.patch(
"authorization.middleware.get_authorization_resolvers",
return_value=(NoopRolesResolver(), access_resolver),
)

Comment on lines +1546 to +1550
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure the test uses restricted authorized_actions.
dummy_request fixture sets all actions; override it here so validate_model_provider_override takes the 403 path.

Apply this diff:

     mocker.patch(
         "authorization.middleware.get_authorization_resolvers",
         return_value=(NoopRolesResolver(), access_resolver),
     )
 
+    # Simulate middleware: restrict request actions for this test
+    dummy_request.state.authorized_actions = set(Action) - {Action.MODEL_OVERRIDE}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
mocker.patch(
"authorization.middleware.get_authorization_resolvers",
return_value=(NoopRolesResolver(), access_resolver),
)
mocker.patch(
"authorization.middleware.get_authorization_resolvers",
return_value=(NoopRolesResolver(), access_resolver),
)
# Simulate middleware: restrict request actions for this test
dummy_request.state.authorized_actions = set(Action) - {Action.MODEL_OVERRIDE}
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_query.py around lines 1547 to 1551, the test
currently inherits the dummy_request fixture, which grants all actions; to force
validate_model_provider_override to hit the 403 branch, override the request's
authorized_actions with a restricted set (e.g., an empty set or a set that does not
include the required action for this endpoint). Update the test to either set
dummy_request.state.authorized_actions = set() (or the minimal disallowed set) before
invoking the endpoint, or use mocker.patch to return a resolver whose
authorized_actions lacks the needed permission, so the call follows the 403
path.

# Build a request that tries to override model/provider
query_request = QueryRequest(query="What?", model="m", provider="p")

with pytest.raises(HTTPException) as exc_info:
await query_endpoint_handler(
request=dummy_request, query_request=query_request, auth=MOCK_AUTH
)

expected_msg = (
"This instance does not permit overriding model/provider in the query request "
"(missing permission: MODEL_OVERRIDE). Please remove the model and provider "
"fields from your request."
)
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
assert exc_info.value.detail["response"] == expected_msg
61 changes: 60 additions & 1 deletion tests/unit/app/endpoints/test_streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
)

from models.requests import QueryRequest, Attachment
from models.config import ModelContextProtocolServer
from models.config import ModelContextProtocolServer, Action
from authorization.resolvers import NoopRolesResolver
from utils.types import ToolCallSummary, TurnSummary

MOCK_AUTH = ("mock_user_id", "mock_username", "mock_token")
Expand Down Expand Up @@ -1515,3 +1516,61 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
stream=True,
toolgroups=expected_toolgroups,
)


@pytest.mark.asyncio
async def test_streaming_query_endpoint_rejects_model_provider_override_without_permission(
mocker,
):
"""Assert 403 when request includes model/provider without MODEL_OVERRIDE."""
cfg = AppConfig()
cfg.init_from_dict(
{
"name": "test",
"service": {
"host": "localhost",
"port": 8080,
"auth_enabled": False,
"workers": 1,
"color_log": True,
"access_log": True,
},
"llama_stack": {
"api_key": "test-key",
"url": "http://test.com:1234",
"use_as_library_client": False,
},
"user_data_collection": {"transcripts_enabled": False},
"mcp_servers": [],
}
)
mocker.patch("app.endpoints.streaming_query.configuration", cfg)

# Patch authorization to exclude MODEL_OVERRIDE from authorized actions
access_resolver = mocker.Mock()
access_resolver.check_access.return_value = True
access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE}
mocker.patch(
"authorization.middleware.get_authorization_resolvers",
return_value=(NoopRolesResolver(), access_resolver),
)

# Build a query request that tries to override model/provider
query_request = QueryRequest(query="What?", model="m", provider="p")

request = Request(
scope={
"type": "http",
}
)

with pytest.raises(HTTPException) as exc_info:
Comment on lines +1561 to +1567
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Set request.state.authorized_actions in the test to avoid AttributeError and ensure 403 path.
The handler reads request.state.authorized_actions; when calling the endpoint directly, middleware isn’t invoked. Without setting this, the test can fail or miss the intended 403 branch.

Apply this diff after creating the Request:

     request = Request(
         scope={
             "type": "http",
         }
     )
 
+    # Simulate middleware-provided actions for this test
+    request.state.authorized_actions = set(Action) - {Action.MODEL_OVERRIDE}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
request = Request(
scope={
"type": "http",
}
)
with pytest.raises(HTTPException) as exc_info:
request = Request(
scope={
"type": "http",
}
)
# Simulate middleware-provided actions for this test
request.state.authorized_actions = set(Action) - {Action.MODEL_OVERRIDE}
with pytest.raises(HTTPException) as exc_info:
🤖 Prompt for AI Agents
In tests/unit/app/endpoints/test_streaming_query.py around lines 1562-1568, the
test constructs a Request but doesn’t set request.state.authorized_actions so
the handler raises AttributeError or skips the 403 branch; after creating the
Request, set request.state.authorized_actions to the appropriate
empty/unauthorized value used by the app (e.g., an empty set or list) so the
attribute exists and the endpoint takes the 403 path.

await streaming_query_endpoint_handler(request, query_request, auth=MOCK_AUTH)

expected_msg = (
"This instance does not permit overriding model/provider in the query request "
"(missing permission: MODEL_OVERRIDE). Please remove the model and provider "
"fields from your request."
)
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
assert exc_info.value.detail["response"] == expected_msg
23 changes: 23 additions & 0 deletions tests/unit/utils/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tests.unit import config_dict

from models.requests import QueryRequest
from models.config import Action
from utils import endpoints
from utils.endpoints import get_agent

Expand Down Expand Up @@ -591,3 +592,25 @@ async def test_get_agent_no_tools_false_preserves_parser(
tool_parser=mock_parser,
enable_session_persistence=True,
)


def test_validate_model_provider_override_allowed_with_action():
    """An override request passes validation when MODEL_OVERRIDE is granted."""
    granted = {Action.MODEL_OVERRIDE}
    override_request = QueryRequest(query="q", model="m", provider="p")

    # The call must simply return; any exception fails the test.
    endpoints.validate_model_provider_override(override_request, granted)


def test_validate_model_provider_override_rejected_without_action():
    """An override request is refused with HTTP 403 when no action is granted."""
    nothing_granted: set[Action] = set()
    override_request = QueryRequest(query="q", model="m", provider="p")

    with pytest.raises(HTTPException) as raised:
        endpoints.validate_model_provider_override(override_request, nothing_granted)

    # Checked after the context manager exits, once the exception is captured.
    assert raised.value.status_code == 403


def test_validate_model_provider_override_no_override_without_action():
    """A request without model/provider passes even when no action is granted."""
    plain_request = QueryRequest(query="q")
    nothing_granted = set()

    # No override is requested, so the permission check must not trigger.
    endpoints.validate_model_provider_override(plain_request, nothing_granted)