From 7ffe3fefcf946b1ad8679bb58196197045f8e5c0 Mon Sep 17 00:00:00 2001 From: Sundar Raghavan Date: Thu, 30 Oct 2025 19:36:24 -0700 Subject: [PATCH 1/7] feat: add middleware integration with namespace isolation and true immutability Add comprehensive middleware support to BedrockAgentCore SDK, enabling developers to inject cross-cutting concerns (authentication, logging, timing, metrics) into the request pipeline. --- src/bedrock_agentcore/runtime/__init__.py | 14 +- src/bedrock_agentcore/runtime/app.py | 98 +++--- src/bedrock_agentcore/runtime/context.py | 297 ++++++++++++++++-- tests/bedrock_agentcore/runtime/test_app.py | 9 + .../bedrock_agentcore/runtime/test_context.py | 124 ++++++++ 5 files changed, 483 insertions(+), 59 deletions(-) diff --git a/src/bedrock_agentcore/runtime/__init__.py b/src/bedrock_agentcore/runtime/__init__.py index 13107ff..098c7ba 100644 --- a/src/bedrock_agentcore/runtime/__init__.py +++ b/src/bedrock_agentcore/runtime/__init__.py @@ -4,10 +4,20 @@ - BedrockAgentCoreApp: Main application class - RequestContext: HTTP request context - BedrockAgentCoreContext: Agent identity context +- ProcessingContext +- StandardNamespaces """ from .app import BedrockAgentCoreApp -from .context import BedrockAgentCoreContext, RequestContext +from .context import AgentContext, BedrockAgentCoreContext, ProcessingContext, RequestContext, StandardNamespaces from .models import PingStatus -__all__ = ["BedrockAgentCoreApp", "RequestContext", "BedrockAgentCoreContext", "PingStatus"] +__all__ = [ + "BedrockAgentCoreApp", + "RequestContext", + "ProcessingContext", + "AgentContext", + "BedrockAgentCoreContext", + "PingStatus", + "StandardNamespaces", +] diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index 2b84056..2d51a46 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -20,14 +20,8 @@ from starlette.routing import Route from starlette.types import Lifespan -from .context import BedrockAgentCoreContext, RequestContext +from .context import AgentContext, BedrockAgentCoreContext, ProcessingContext, RequestContext from .models import ( - ACCESS_TOKEN_HEADER, - AUTHORIZATION_HEADER, - CUSTOM_HEADER_PREFIX, - OAUTH2_CALLBACK_URL_HEADER, - REQUEST_ID_HEADER, - SESSION_HEADER, TASK_ACTION_CLEAR_FORCED_STATUS, TASK_ACTION_FORCE_BUSY, TASK_ACTION_FORCE_HEALTHY, @@ -275,51 +269,79 @@ def complete_async_task(self, task_id: int) -> bool: self.logger.warning("Attempted to complete unknown task ID: %s", task_id) return False - def _build_request_context(self, request) -> RequestContext: - """Build request context and setup all context variables.""" - try: - headers = request.headers - request_id = headers.get(REQUEST_ID_HEADER) - if not request_id: - request_id = str(uuid.uuid4()) + def _build_request_context(self, request) -> AgentContext: + """Build request context from incoming request. - session_id = headers.get(SESSION_HEADER) - BedrockAgentCoreContext.set_request_context(request_id, session_id) + Args: + request: Starlette Request object - agent_identity_token = headers.get(ACCESS_TOKEN_HEADER) - if agent_identity_token: - BedrockAgentCoreContext.set_workload_access_token(agent_identity_token) + Returns: + AgentContext with both request and processing data + """ + try: + # Generate request ID + request_id = str(uuid.uuid4()) - oauth2_callback_url = headers.get(OAUTH2_CALLBACK_URL_HEADER) - if oauth2_callback_url: - BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url) + # Extract session ID from headers (case-insensitive) + session_id = request.headers.get("x-amzn-bedrock-agentcore-runtime-session-id") - # Collect relevant request headers (Authorization + Custom headers) + # Extract relevant headers (Authorization + Custom headers) request_headers = {} - # Add Authorization header if present - authorization_header = headers.get(AUTHORIZATION_HEADER) - if authorization_header is not None: - request_headers[AUTHORIZATION_HEADER] = authorization_header + # Create a case-insensitive lookup + headers_lower = {} + for key, value in request.headers.items(): + headers_lower[key.lower()] = (key, value) # Store original key + value + + # Check for Authorization header (case-insensitive) + if "authorization" in headers_lower: + original_key, value = headers_lower["authorization"] + request_headers["Authorization"] = value # Normalize to "Authorization" + + # Extract custom headers with the specific prefix (case-insensitive) + custom_prefix = "x-amzn-bedrock-agentcore-runtime-custom-" + for lower_key, (original_key, value) in headers_lower.items(): + if lower_key.startswith(custom_prefix): + request_headers[original_key] = value # Keep original case - # Add custom headers with the specified prefix - for header_name, header_value in headers.items(): - if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()): - request_headers[header_name] = header_value + # Convert empty dict to None + request_headers = request_headers if request_headers else None - # Set in context if any headers were found + # Set in BedrockAgentCoreContext for global access + BedrockAgentCoreContext.set_request_context(request_id, session_id) if request_headers: BedrockAgentCoreContext.set_request_headers(request_headers) - # Get the headers from context to pass to RequestContext - req_headers = BedrockAgentCoreContext.get_request_headers() + # Check if middleware injected processing data + processing_context = ProcessingContext() + if hasattr(request, "state") and hasattr(request.state, "processing_data"): + # Middleware has injected data + processing_data = request.state.processing_data + + # Handle both flat and namespaced data structures + if processing_data: + # Check if it's already namespaced (dict of dicts) + if all(isinstance(v, dict) for v in processing_data.values()): + # Already namespaced + processing_context.middleware_data = processing_data + else: + # Flat structure - put in 'default' namespace + processing_context.middleware_data["default"] = processing_data + + BedrockAgentCoreContext.set_processing_context(processing_context) + + # Create and return AgentContext + return AgentContext( + request=RequestContext(session_id=session_id, request_headers=request_headers), + processing=processing_context, + ) - return RequestContext(session_id=session_id, request_headers=req_headers) except Exception as e: self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) - request_id = str(uuid.uuid4()) - BedrockAgentCoreContext.set_request_context(request_id, None) - return RequestContext(session_id=None) + # Return minimal context on error + return AgentContext( + request=RequestContext(session_id=None, request_headers=None), processing=ProcessingContext() + ) def _takes_context(self, handler: Callable) -> bool: try: diff --git a/src/bedrock_agentcore/runtime/context.py b/src/bedrock_agentcore/runtime/context.py index b32fbe1..4ac0cdb 100644 --- a/src/bedrock_agentcore/runtime/context.py +++ b/src/bedrock_agentcore/runtime/context.py @@ -3,36 +3,229 @@ Contains metadata extracted from HTTP requests that handlers can optionally access. """ +import re from contextvars import ContextVar -from typing import Dict, Optional +from types import MappingProxyType +from typing import Any, ClassVar, Dict, List, Mapping, Optional from pydantic import BaseModel, Field class RequestContext(BaseModel): - """Request context containing metadata from HTTP requests.""" + """Immutable request context containing metadata from HTTP requests. - session_id: Optional[str] = Field(None) - request_headers: Optional[Dict[str, str]] = Field(None) + This represents what the client sent and should never be modified. + Uses MappingProxyType via @property for truly immutable headers. + + Attributes: + session_id: Session identifier from X-Amzn-Bedrock-AgentCore-Runtime-Session-Id header + request_headers: Immutable headers sent by client (exposed as MappingProxyType) + """ + + session_id: Optional[str] = Field(None, description="Session identifier") + request_headers_dict: Optional[Dict[str, str]] = Field( + None, alias="request_headers", description="Internal headers storage" + ) + + class Config: + """Pydantic model configuration for immutability.""" + + frozen = True + arbitrary_types_allowed = True + + @property + def request_headers(self) -> Optional[Mapping[str, str]]: + """Return headers as immutable MappingProxyType. + + Returns: + MappingProxyType wrapping headers dict, or None if no headers + """ + if self.request_headers_dict is None: + return None + return MappingProxyType(self.request_headers_dict) + + +class ProcessingContext(BaseModel): + """Mutable processing context with namespace isolation. + + This is where middleware can inject data that handlers need to see. + Uses namespaces to prevent accidental collisions and make malicious overwrites obvious. + + Attributes: + middleware_data: Namespaced data injected by middleware layers + processing_metadata: Metadata about request processing (timestamps, traces, etc) + """ + + middleware_data: Dict[str, Dict[str, Any]] = Field( + default_factory=dict, description="Namespaced middleware-injected data" + ) + processing_metadata: Dict[str, Any] = Field(default_factory=dict, description="Processing metadata") + + # ClassVar for Pydantic v2 compatibility + DENY_PATTERNS: ClassVar[List[str]] = [ + r"\x00", + r"[\r\n].*?:", + r"\s]", + r"javascript:", + r"\$\{", + ] + + def set( + self, + namespace: str, + key: str, + value: Any, + max_size: int = 100_000, + validate: bool = False, + ) -> None: + """Set middleware data in a specific namespace with optional validation. + + Args: + namespace: Namespace to isolate this data (e.g., 'auth', 'metrics', 'timing') + key: Data key within the namespace + value: Data value (must be JSON-serializable) + max_size: Maximum size for string values (default 100KB) + validate: Whether to perform security validation (default False for performance) + + Raises: + ValueError: If validation is enabled and value fails checks + + Example: + context.processing.set('auth', 'user_id', 'user123') + """ + if validate and isinstance(value, str): + if len(value) > max_size: + raise ValueError(f"Value too large: {len(value)} bytes > {max_size} bytes") + + for pattern in self.DENY_PATTERNS: + if re.search(pattern, value, re.IGNORECASE): + raise ValueError("Potentially unsafe pattern detected in value") + + if namespace not in self.middleware_data: + self.middleware_data[namespace] = {} + + self.middleware_data[namespace][key] = value + + def get(self, namespace: str, key: str, default: Any = None) -> Any: + """Get middleware data from a specific namespace. + + Args: + namespace: Namespace to read from + key: Data key within the namespace + default: Default value if key not found + + Returns: + Value associated with key in namespace, or default + + Example: + user_id = context.processing.get('auth', 'user_id') + """ + return self.middleware_data.get(namespace, {}).get(key, default) + + def get_namespace(self, namespace: str) -> Dict[str, Any]: + """Get all data from a specific namespace. + + Args: + namespace: Namespace to retrieve + + Returns: + Dictionary of all key-value pairs in the namespace + + Example: + auth_data = context.processing.get_namespace('auth') + """ + return self.middleware_data.get(namespace, {}).copy() + + def has_namespace(self, namespace: str) -> bool: + """Check if a namespace exists. + + Args: + namespace: Namespace to check + + Returns: + True if namespace exists, False otherwise + """ + return namespace in self.middleware_data + + def add_metadata(self, key: str, value: Any) -> None: + """Add processing metadata for observability. + + Args: + key: Metadata key + value: Metadata value + + Example: + context.processing.add_metadata('handler_start_time', time.time()) + """ + self.processing_metadata[key] = value + + +class AgentContext(BaseModel): + """Combined context that handlers receive. + + Provides access to both immutable request data and mutable processing data. + + Attributes: + request: Immutable request data (what client sent) + processing: Mutable processing data (what middleware added, namespaced) + """ + + request: RequestContext + processing: ProcessingContext + + class Config: + """Pydantic model configuration.""" + + arbitrary_types_allowed = True + + @property + def session_id(self) -> Optional[str]: + """Session ID (backward compatibility property). + + Returns: + Session ID from request context + """ + return self.request.session_id + + @property + def request_headers(self) -> Optional[Mapping[str, str]]: + """Request headers (backward compatibility property). + + Returns: + Immutable request headers as MappingProxyType + """ + return self.request.request_headers class BedrockAgentCoreContext: - """Unified context manager for Bedrock AgentCore.""" + """Unified context manager for Bedrock AgentCore. - _workload_access_token: ContextVar[Optional[str]] = ContextVar("workload_access_token") - _oauth2_callback_url: ContextVar[Optional[str]] = ContextVar("oauth2_callback_url") - _request_id: ContextVar[Optional[str]] = ContextVar("request_id") - _session_id: ContextVar[Optional[str]] = ContextVar("session_id") - _request_headers: ContextVar[Optional[Dict[str, str]]] = ContextVar("request_headers") + Uses Python's contextvars for thread-safe, request-scoped storage. + """ + + _workload_access_token: ContextVar[Optional[str]] = ContextVar("workload_access_token", default=None) + _oauth2_callback_url: ContextVar[Optional[str]] = ContextVar("oauth2_callback_url", default=None) + _request_id: ContextVar[Optional[str]] = ContextVar("request_id", default=None) + _session_id: ContextVar[Optional[str]] = ContextVar("session_id", default=None) + _request_headers: ContextVar[Optional[Dict[str, str]]] = ContextVar("request_headers", default=None) + _processing_context: ContextVar[Optional[ProcessingContext]] = ContextVar("processing_context", default=None) @classmethod def set_workload_access_token(cls, token: str): - """Set the workload access token in the context.""" + """Set the workload access token in the context. + + Args: + token: Workload access token + """ cls._workload_access_token.set(token) @classmethod def get_workload_access_token(cls) -> Optional[str]: - """Get the workload access token from the context.""" + """Get the workload access token from the context. + + Returns: + Workload access token or None if not set + """ try: return cls._workload_access_token.get() except LookupError: @@ -40,12 +233,20 @@ def get_workload_access_token(cls) -> Optional[str]: @classmethod def set_oauth2_callback_url(cls, workload_callback_url: str): - """Set the oauth2 callback url in the context.""" + """Set the OAuth2 callback URL in the context. + + Args: + workload_callback_url: OAuth2 callback URL + """ cls._oauth2_callback_url.set(workload_callback_url) @classmethod def get_oauth2_callback_url(cls) -> Optional[str]: - """Get the oauth2 callback url from the context.""" + """Get the OAuth2 callback URL from the context. + + Returns: + OAuth2 callback URL or None if not set + """ try: return cls._oauth2_callback_url.get() except LookupError: @@ -53,13 +254,22 @@ def get_oauth2_callback_url(cls) -> Optional[str]: @classmethod def set_request_context(cls, request_id: str, session_id: Optional[str] = None): - """Set request-scoped identifiers.""" + """Set request-scoped identifiers. + + Args: + request_id: Unique request identifier + session_id: Optional session identifier + """ cls._request_id.set(request_id) cls._session_id.set(session_id) @classmethod def get_request_id(cls) -> Optional[str]: - """Get current request ID.""" + """Get current request ID. + + Returns: + Request ID or None if not set + """ try: return cls._request_id.get() except LookupError: @@ -67,7 +277,11 @@ def get_request_id(cls) -> Optional[str]: @classmethod def get_session_id(cls) -> Optional[str]: - """Get current session ID.""" + """Get current session ID. + + Returns: + Session ID or None if not set + """ try: return cls._session_id.get() except LookupError: @@ -75,13 +289,58 @@ def get_session_id(cls) -> Optional[str]: @classmethod def set_request_headers(cls, headers: Dict[str, str]): - """Set request headers in the context.""" + """Set request headers in the context. + + Args: + headers: Dictionary of request headers + """ cls._request_headers.set(headers) @classmethod def get_request_headers(cls) -> Optional[Dict[str, str]]: - """Get request headers from the context.""" + """Get request headers from the context. + + Returns: + Request headers dictionary or None if not set + """ try: return cls._request_headers.get() except LookupError: return None + + @classmethod + def set_processing_context(cls, context: ProcessingContext): + """Set processing context for current request. + + Args: + context: ProcessingContext instance with middleware data + """ + cls._processing_context.set(context) + + @classmethod + def get_processing_context(cls) -> Optional[ProcessingContext]: + """Get processing context for current request. + + Returns: + ProcessingContext instance or None if not set + """ + try: + return cls._processing_context.get() + except LookupError: + return None + + +class StandardNamespaces: + """Standard namespaces for common middleware patterns. + + Using these constants helps prevent typos and ensures consistency. + """ + + AUTH = "auth" + TIMING = "timing" + AUDIT = "audit" + METRICS = "metrics" + OBSERVABILITY = "observability" + FEATURE_FLAGS = "features" + RATE_LIMIT = "rate_limit" + CUSTOM = "custom" diff --git a/tests/bedrock_agentcore/runtime/test_app.py b/tests/bedrock_agentcore/runtime/test_app.py index c0c2568..65f86ba 100644 --- a/tests/bedrock_agentcore/runtime/test_app.py +++ b/tests/bedrock_agentcore/runtime/test_app.py @@ -1599,6 +1599,7 @@ def test_build_request_context_with_authorization_header(self): class MockRequest: def __init__(self): self.headers = {"Authorization": "Bearer test-auth-token", "Content-Type": "application/json"} + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1619,6 +1620,7 @@ def __init__(self): "X-Other-Header": "should-not-include", "Content-Type": "application/json", } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1642,6 +1644,7 @@ def __init__(self): "Content-Type": "application/json", "X-Other-Header": "ignored", } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1712,6 +1715,7 @@ def __init__(self): "X-AMZN-BEDROCK-AGENTCORE-RUNTIME-CUSTOM-UPPERCASE": "upper-value", "X-Amzn-Bedrock-AgentCore-Runtime-Custom-MixedCase": "mixed-value", } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1736,6 +1740,7 @@ def __init__(self): "X-Amzn-Bedrock-AgentCore-Runtime-Request-Id": "test-request-123", "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id": "test-session-456", } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1839,6 +1844,7 @@ def __init__(self): "X-Amzn-Bedrock-AgentCore-Runtime-Custom-Spaces": "value with spaces", "X-Amzn-Bedrock-AgentCore-Runtime-Custom-Quotes": 'value-with-"quotes"', } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1866,6 +1872,7 @@ def __init__(self): # Prefix as substring - should NOT be included "PrefixX-Amzn-Bedrock-AgentCore-Runtime-Custom-": "has-prefix", } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1894,6 +1901,7 @@ def __init__(self): "Proxy-Authorization": "Bearer proxy-token", # Should NOT be included "X-Amzn-Bedrock-AgentCore-Runtime-Custom-Auth": "Bearer custom-token", # Should be included } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) @@ -1924,6 +1932,7 @@ def __init__(self): "X-Amzn-Bedrock-AgentCore-Runtime-Custom-Empty": "", # Empty custom header "X-Amzn-Bedrock-AgentCore-Runtime-Custom-Valid": "valid-value", } + self.state = type("State", (), {})() mock_request = MockRequest() context = app._build_request_context(mock_request) diff --git a/tests/bedrock_agentcore/runtime/test_context.py b/tests/bedrock_agentcore/runtime/test_context.py index 61f41d5..41defc9 100644 --- a/tests/bedrock_agentcore/runtime/test_context.py +++ b/tests/bedrock_agentcore/runtime/test_context.py @@ -2,6 +2,8 @@ import contextvars +from pydantic import ValidationError + from bedrock_agentcore.runtime.context import BedrockAgentCoreContext, RequestContext @@ -183,3 +185,125 @@ def test_request_context_with_empty_headers(self): assert context.session_id == "test-session-789" assert context.request_headers == {} + + +class TestProcessingContext: + """Test ProcessingContext with namespace isolation.""" + + def test_processing_context_set_and_get(self): + """Test setting and getting namespaced data.""" + from bedrock_agentcore.runtime.context import ProcessingContext, StandardNamespaces + + context = ProcessingContext() + context.set(StandardNamespaces.AUTH, "user_id", "alice") + + result = context.get(StandardNamespaces.AUTH, "user_id") + assert result == "alice" + + def test_processing_context_namespace_isolation(self): + """Test that namespaces are isolated.""" + from bedrock_agentcore.runtime.context import ProcessingContext + + context = ProcessingContext() + context.set("auth", "user_id", "alice") + context.set("timing", "user_id", "bob") # Different namespace, same key + + assert context.get("auth", "user_id") == "alice" + assert context.get("timing", "user_id") == "bob" + + def test_processing_context_get_namespace(self): + """Test getting entire namespace.""" + from bedrock_agentcore.runtime.context import ProcessingContext + + context = ProcessingContext() + context.set("auth", "user_id", "alice") + context.set("auth", "role", "admin") + + auth_data = context.get_namespace("auth") + assert auth_data == {"user_id": "alice", "role": "admin"} + + def test_processing_context_has_namespace(self): + """Test checking namespace existence.""" + from bedrock_agentcore.runtime.context import ProcessingContext + + context = ProcessingContext() + assert not context.has_namespace("auth") + + context.set("auth", "user_id", "alice") + assert context.has_namespace("auth") + + def test_processing_context_validation(self): + """Test optional validation.""" + import pytest + + from bedrock_agentcore.runtime.context import ProcessingContext + + context = ProcessingContext() + + # Should work without validation + context.set("custom", "data", "x" * 200, validate=False) + + # Should fail with validation (exceeds max_size) + with pytest.raises(ValueError, match="Value too large"): + context.set("custom", "large", "x" * 200, validate=True, max_size=100) + + # Should fail with malicious pattern + with pytest.raises(ValueError, match="Potentially unsafe pattern"): + context.set("custom", "bad", "test\x00null", validate=True) + + +class TestAgentContext: + """Test AgentContext integration.""" + + def test_agent_context_backward_compatibility(self): + """Test backward compatibility properties.""" + from bedrock_agentcore.runtime.context import AgentContext, ProcessingContext, RequestContext + + request = RequestContext(session_id="test-123", request_headers={"Authorization": "Bearer token"}) + processing = ProcessingContext() + + context = AgentContext(request=request, processing=processing) + + # Backward compatibility + assert context.session_id == "test-123" + assert context.request_headers is not None + assert "Authorization" in context.request_headers + + def test_agent_context_immutability(self): + """Test that request context is immutable.""" + from types import MappingProxyType + + import pytest + + from bedrock_agentcore.runtime.context import AgentContext, ProcessingContext, RequestContext + + request = RequestContext(session_id="test-123", request_headers={"Authorization": "Bearer token"}) + processing = ProcessingContext() + context = AgentContext(request=request, processing=processing) + + # Cannot modify frozen request + with pytest.raises((ValidationError, AttributeError, TypeError)): + context.request.session_id = "hacked" + + # Cannot modify headers (MappingProxyType) + assert isinstance(context.request_headers, MappingProxyType) + with pytest.raises(TypeError): + context.request_headers["new"] = "value" + + # CAN modify processing context + context.processing.set("custom", "key", "value") + assert context.processing.get("custom", "key") == "value" + + +class TestStandardNamespaces: + """Test StandardNamespaces constants.""" + + def test_standard_namespaces_values(self): + """Test that standard namespace constants have expected values.""" + from bedrock_agentcore.runtime.context import StandardNamespaces + + assert StandardNamespaces.AUTH == "auth" + assert StandardNamespaces.TIMING == "timing" + assert StandardNamespaces.AUDIT == "audit" + assert StandardNamespaces.METRICS == "metrics" + assert StandardNamespaces.CUSTOM == "custom" From 89d20869170cf481c77974d1b00691d2d1637211 Mon Sep 17 00:00:00 2001 From: Sundar Raghavan Date: Thu, 30 Oct 2025 19:46:26 -0700 Subject: [PATCH 2/7] fix: imports in app.py --- src/bedrock_agentcore/runtime/app.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index 2d51a46..7bfbbe3 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -22,6 +22,12 @@ from .context import AgentContext, BedrockAgentCoreContext, ProcessingContext, RequestContext from .models import ( + ACCESS_TOKEN_HEADER, + AUTHORIZATION_HEADER, + CUSTOM_HEADER_PREFIX, + OAUTH2_CALLBACK_URL_HEADER, + REQUEST_ID_HEADER, + SESSION_HEADER, TASK_ACTION_CLEAR_FORCED_STATUS, TASK_ACTION_FORCE_BUSY, TASK_ACTION_FORCE_HEALTHY, From 1ee621c659fa30cac7d6d6c1ee466baa401b9e4a Mon Sep 17 00:00:00 2001 From: Sundar Raghavan Date: Thu, 30 Oct 2025 19:48:27 -0700 Subject: [PATCH 3/7] fix: update constants in app.py --- src/bedrock_agentcore/runtime/app.py | 52 +++++++++++++++------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index 7bfbbe3..6308d4a 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -277,53 +277,53 @@ def complete_async_task(self, task_id: int) -> bool: def _build_request_context(self, request) -> AgentContext: """Build request context from incoming request. - + Args: request: Starlette Request object - + Returns: AgentContext with both request and processing data """ try: # Generate request ID request_id = str(uuid.uuid4()) - + # Extract session ID from headers (case-insensitive) - session_id = request.headers.get("x-amzn-bedrock-agentcore-runtime-session-id") - + session_id = request.headers.get(SESSION_HEADER) + # Extract relevant headers (Authorization + Custom headers) request_headers = {} - - # Create a case-insensitive lookup + + # Handle case-insensitive header lookup for both Starlette Headers and plain dicts headers_lower = {} for key, value in request.headers.items(): headers_lower[key.lower()] = (key, value) # Store original key + value - + # Check for Authorization header (case-insensitive) - if "authorization" in headers_lower: - original_key, value = headers_lower["authorization"] + if AUTHORIZATION_HEADER.lower() in headers_lower: + original_key, value = headers_lower[AUTHORIZATION_HEADER.lower()] request_headers["Authorization"] = value # Normalize to "Authorization" - + # Extract custom headers with the specific prefix (case-insensitive) - custom_prefix = "x-amzn-bedrock-agentcore-runtime-custom-" + custom_prefix_lower = CUSTOM_HEADER_PREFIX.lower() for lower_key, (original_key, value) in headers_lower.items(): - if lower_key.startswith(custom_prefix): + if lower_key.startswith(custom_prefix_lower): request_headers[original_key] = value # Keep original case - + # Convert empty dict to None request_headers = request_headers if request_headers else None - + # Set in BedrockAgentCoreContext for global access BedrockAgentCoreContext.set_request_context(request_id, session_id) if request_headers: BedrockAgentCoreContext.set_request_headers(request_headers) - + # Check if middleware injected processing data processing_context = ProcessingContext() if hasattr(request, "state") and hasattr(request.state, "processing_data"): # Middleware has injected data processing_data = request.state.processing_data - + # Handle both flat and namespaced data structures if processing_data: # Check if it's already namespaced (dict of dicts) @@ -332,21 +332,25 @@ def _build_request_context(self, request) -> AgentContext: processing_context.middleware_data = processing_data else: # Flat structure - put in 'default' namespace - processing_context.middleware_data["default"] = processing_data - + processing_context.middleware_data['default'] = processing_data + BedrockAgentCoreContext.set_processing_context(processing_context) - + # Create and return AgentContext return AgentContext( - request=RequestContext(session_id=session_id, request_headers=request_headers), - processing=processing_context, + request=RequestContext( + session_id=session_id, + request_headers=request_headers + ), + processing=processing_context ) - + except Exception as e: self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) # Return minimal context on error return AgentContext( - request=RequestContext(session_id=None, request_headers=None), processing=ProcessingContext() + request=RequestContext(session_id=None, request_headers=None), + processing=ProcessingContext() ) def _takes_context(self, handler: Callable) -> bool: From 753e5c040be04655244d77100b382635ed3fb543 Mon Sep 17 00:00:00 2001 From: Sundar Raghavan Date: Thu, 30 Oct 2025 19:50:56 -0700 Subject: [PATCH 4/7] fix: lint errors in app.py --- src/bedrock_agentcore/runtime/app.py | 43 ++++++++++++---------------- 1 file changed, 18 insertions(+), 25 deletions(-) diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index 6308d4a..389bc2a 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -22,11 +22,8 @@ from .context import AgentContext, BedrockAgentCoreContext, ProcessingContext, RequestContext from .models import ( - ACCESS_TOKEN_HEADER, AUTHORIZATION_HEADER, CUSTOM_HEADER_PREFIX, - OAUTH2_CALLBACK_URL_HEADER, - REQUEST_ID_HEADER, SESSION_HEADER, TASK_ACTION_CLEAR_FORCED_STATUS, TASK_ACTION_FORCE_BUSY, @@ -277,53 +274,53 @@ def complete_async_task(self, task_id: int) -> bool: def _build_request_context(self, request) -> AgentContext: """Build request context from incoming request. - + Args: request: Starlette Request object - + Returns: AgentContext with both request and processing data """ try: # Generate request ID request_id = str(uuid.uuid4()) - + # Extract session ID from headers (case-insensitive) session_id = request.headers.get(SESSION_HEADER) - + # Extract relevant headers (Authorization + Custom headers) request_headers = {} - + # Handle case-insensitive header lookup for both Starlette Headers and plain dicts headers_lower = {} for key, value in request.headers.items(): headers_lower[key.lower()] = (key, value) # Store original key + value - + # Check for Authorization header (case-insensitive) if AUTHORIZATION_HEADER.lower() in headers_lower: original_key, value = headers_lower[AUTHORIZATION_HEADER.lower()] request_headers["Authorization"] = value # Normalize to "Authorization" - + # Extract custom headers with the specific prefix (case-insensitive) custom_prefix_lower = CUSTOM_HEADER_PREFIX.lower() for lower_key, (original_key, value) in headers_lower.items(): if lower_key.startswith(custom_prefix_lower): request_headers[original_key] = value # Keep original case - + # Convert empty dict to None request_headers = request_headers if request_headers else None - + # Set in BedrockAgentCoreContext for global access BedrockAgentCoreContext.set_request_context(request_id, session_id) if request_headers: BedrockAgentCoreContext.set_request_headers(request_headers) - + # Check if middleware injected processing data processing_context = ProcessingContext() if hasattr(request, "state") and hasattr(request.state, "processing_data"): # Middleware has injected data processing_data = request.state.processing_data - + # Handle both flat and namespaced data structures if processing_data: # Check if it's already namespaced (dict of dicts) @@ -332,25 +329,21 @@ def _build_request_context(self, request) -> AgentContext: processing_context.middleware_data = processing_data else: # Flat structure - put in 'default' namespace - processing_context.middleware_data['default'] = processing_data - + processing_context.middleware_data["default"] = processing_data + BedrockAgentCoreContext.set_processing_context(processing_context) - + # Create and return AgentContext return AgentContext( - request=RequestContext( - session_id=session_id, - request_headers=request_headers - ), - processing=processing_context + request=RequestContext(session_id=session_id, request_headers=request_headers), + processing=processing_context, ) - + except Exception as e: self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) # Return minimal context on error return AgentContext( - request=RequestContext(session_id=None, request_headers=None), - processing=ProcessingContext() + request=RequestContext(session_id=None, request_headers=None), processing=ProcessingContext() ) def _takes_context(self, handler: Callable) -> bool: From 865850901c293d3e96722abeb34b7a794f22f9a6 Mon Sep 17 00:00:00 2001 From: Sundar Raghavan Date: Thu, 30 Oct 2025 19:59:54 -0700 Subject: [PATCH 5/7] fix: lint errors in app.py --- src/bedrock_agentcore/runtime/app.py | 70 +++++++++++++++++----------- 1 file changed, 42 insertions(+), 28 deletions(-) diff --git a/src/bedrock_agentcore/runtime/app.py b/src/bedrock_agentcore/runtime/app.py index 389bc2a..ff87886 100644 --- a/src/bedrock_agentcore/runtime/app.py +++ b/src/bedrock_agentcore/runtime/app.py @@ -22,8 +22,11 @@ from .context import AgentContext, BedrockAgentCoreContext, ProcessingContext, RequestContext from .models import ( + ACCESS_TOKEN_HEADER, AUTHORIZATION_HEADER, CUSTOM_HEADER_PREFIX, + OAUTH2_CALLBACK_URL_HEADER, + REQUEST_ID_HEADER, SESSION_HEADER, TASK_ACTION_CLEAR_FORCED_STATUS, TASK_ACTION_FORCE_BUSY, @@ -273,48 +276,56 @@ def complete_async_task(self, task_id: int) -> bool: return False def _build_request_context(self, request) -> AgentContext: - """Build request context from incoming request. + """Build request context and setup all context variables. Args: request: Starlette Request object Returns: - AgentContext with both request and processing data + AgentContext containing both request and processing data """ try: - # Generate request ID - request_id = str(uuid.uuid4()) + headers = request.headers - # Extract session ID from headers (case-insensitive) - session_id = request.headers.get(SESSION_HEADER) + # Extract or generate request ID + request_id = headers.get(REQUEST_ID_HEADER) + if not request_id: + request_id = str(uuid.uuid4()) - # Extract relevant headers (Authorization + Custom headers) - request_headers = {} + # Extract session ID + session_id = headers.get(SESSION_HEADER) + BedrockAgentCoreContext.set_request_context(request_id, session_id) - # Handle case-insensitive header lookup for both Starlette Headers and plain dicts - headers_lower = {} - for key, value in request.headers.items(): - headers_lower[key.lower()] = (key, value) # Store original key + value + # Extract and set workload access token + agent_identity_token = headers.get(ACCESS_TOKEN_HEADER) + if agent_identity_token: + BedrockAgentCoreContext.set_workload_access_token(agent_identity_token) - # Check for Authorization header (case-insensitive) - if AUTHORIZATION_HEADER.lower() in headers_lower: - original_key, value = headers_lower[AUTHORIZATION_HEADER.lower()] - request_headers["Authorization"] = value # Normalize to "Authorization" + # Extract and set OAuth2 callback URL + oauth2_callback_url = headers.get(OAUTH2_CALLBACK_URL_HEADER) + if oauth2_callback_url: + BedrockAgentCoreContext.set_oauth2_callback_url(oauth2_callback_url) - # Extract custom headers with the specific prefix (case-insensitive) - custom_prefix_lower = CUSTOM_HEADER_PREFIX.lower() - for lower_key, (original_key, value) in headers_lower.items(): - if lower_key.startswith(custom_prefix_lower): - request_headers[original_key] = value # Keep original case + # Collect relevant request headers (Authorization + Custom headers) + request_headers = {} - # Convert empty dict to None - request_headers = request_headers if request_headers else None + # Add Authorization header if present + authorization_header = headers.get(AUTHORIZATION_HEADER) + if authorization_header is not None: + request_headers[AUTHORIZATION_HEADER] = authorization_header - # Set in BedrockAgentCoreContext for global access - BedrockAgentCoreContext.set_request_context(request_id, session_id) + # Add custom headers with the specified prefix + for header_name, header_value in headers.items(): + if header_name.lower().startswith(CUSTOM_HEADER_PREFIX.lower()): + request_headers[header_name] = header_value + + # Set in BedrockAgentCoreContext if any headers were found if request_headers: BedrockAgentCoreContext.set_request_headers(request_headers) + # Get the headers from context to pass to RequestContext + req_headers = BedrockAgentCoreContext.get_request_headers() + # Check if middleware injected processing data processing_context = ProcessingContext() if hasattr(request, "state") and hasattr(request.state, "processing_data"): @@ -333,15 +344,18 @@ def _build_request_context(self, request) -> AgentContext: BedrockAgentCoreContext.set_processing_context(processing_context) - # Create and return AgentContext + # Return AgentContext (combines request + processing) return AgentContext( - request=RequestContext(session_id=session_id, request_headers=request_headers), + request=RequestContext(session_id=session_id, request_headers=req_headers), processing=processing_context, ) except Exception as e: self.logger.warning("Failed to build request context: %s: %s", type(e).__name__, e) - # Return minimal context on error + request_id = str(uuid.uuid4()) + BedrockAgentCoreContext.set_request_context(request_id, None) + + # Return minimal AgentContext on error return AgentContext( request=RequestContext(session_id=None, request_headers=None), processing=ProcessingContext() ) From 8ea324810cd051bfb8ad74d3b4e65b11c5eb3f27 Mon Sep 17 00:00:00 2001 From: Sundar Raghavan Date: Fri, 31 Oct 2025 07:40:21 -0700 Subject: [PATCH 6/7] add integration test --- tests_integ/runtime/test_middleware.py | 275 +++++++++++++++++++++++++ 1 file changed, 275 insertions(+) create mode 100644 tests_integ/runtime/test_middleware.py diff --git a/tests_integ/runtime/test_middleware.py b/tests_integ/runtime/test_middleware.py new file mode 100644 index 0000000..f30fc25 --- /dev/null +++ b/tests_integ/runtime/test_middleware.py @@ -0,0 +1,275 @@ +"""Integration tests for middleware.""" + +import subprocess +import time +import requests +import pytest + + +@pytest.fixture(scope="module") +def middleware_server(tmp_path_factory): + """Start a real server with middleware.""" + tmp_dir = tmp_path_factory.mktemp("agent") + agent_file = tmp_dir / "test_agent.py" + + # Write agent with middleware + agent_file.write_text(""" +import time +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware +from bedrock_agentcore import BedrockAgentCoreApp + +class StandardNamespaces: + AUTH = "auth" + TIMING = "timing" + RATE_LIMIT = "rate_limit" + +class AuthMiddleware(BaseHTTPMiddleware): + NAMESPACE = StandardNamespaces.AUTH + + VALID_TOKENS = { + 'test-admin-token': {'user_id': 'admin-1', 'role': 'admin'}, + 'test-user-token': {'user_id': 'user-1', 'role': 'user'}, + } + + async def dispatch(self, request, call_next): + if not hasattr(request.state, 'processing_data'): + request.state.processing_data = {} + + auth_header = request.headers.get('authorization', '') + + if auth_header.startswith('Bearer '): + token = auth_header[7:] + user_info = self.VALID_TOKENS.get(token) + + if user_info: + request.state.processing_data[self.NAMESPACE] = { + 'authenticated': True, + **user_info + } + else: + request.state.processing_data[self.NAMESPACE] = { + 'authenticated': False, + 'error': 'Invalid token' + } + else: + request.state.processing_data[self.NAMESPACE] = { + 'authenticated': False, + 'error': 'No authorization header' + } + + return await call_next(request) + +class TimingMiddleware(BaseHTTPMiddleware): + NAMESPACE = StandardNamespaces.TIMING + + async def dispatch(self, request, call_next): + if not hasattr(request.state, 'processing_data'): + request.state.processing_data = {} + + start_time = time.time() + request.state.processing_data[self.NAMESPACE] = { + 'start_time': start_time + } + + response = await call_next(request) + + duration = time.time() - start_time + response.headers['X-Processing-Time'] = f"{duration:.3f}" + + return response + +class RateLimitMiddleware(BaseHTTPMiddleware): + NAMESPACE = StandardNamespaces.RATE_LIMIT + + def __init__(self, app, max_requests=5): + super().__init__(app) + self.max_requests = max_requests + self.request_counts = {} + + async def dispatch(self, request, call_next): + if not hasattr(request.state, 'processing_data'): + request.state.processing_data = {} + + auth_data = request.state.processing_data.get(StandardNamespaces.AUTH, {}) + user_id = auth_data.get('user_id', 'anonymous') + + current_time = time.time() + + self.request_counts = { + uid: [t for t in times if current_time - t < 60] + for uid, times in self.request_counts.items() + } + + user_requests = self.request_counts.get(user_id, []) + + if len(user_requests) >= self.max_requests: + request.state.processing_data[self.NAMESPACE] = { + 'allowed': False, + 'limit': self.max_requests + } + else: + user_requests.append(current_time) + self.request_counts[user_id] = user_requests + + request.state.processing_data[self.NAMESPACE] = { + 'allowed': True, + 'remaining': self.max_requests - len(user_requests) + } + + return await call_next(request) + +app = BedrockAgentCoreApp( + middleware=[ + Middleware(TimingMiddleware), + Middleware(AuthMiddleware), + Middleware(RateLimitMiddleware, max_requests=5), + ] +) + +@app.entrypoint +def handler(payload, context): + authenticated = context.processing.get(StandardNamespaces.AUTH, 'authenticated', False) + + if not authenticated: + error = context.processing.get(StandardNamespaces.AUTH, 'error', 'Unknown') + return {"error": "Authentication required", "details": error} + + rate_allowed = context.processing.get(StandardNamespaces.RATE_LIMIT, 'allowed', True) + + if not rate_allowed: + limit = context.processing.get(StandardNamespaces.RATE_LIMIT, 'limit') + return {"error": "Rate limit exceeded", "limit": limit} + + user_id = context.processing.get(StandardNamespaces.AUTH, 'user_id') + role = context.processing.get(StandardNamespaces.AUTH, 'role') + remaining = context.processing.get(StandardNamespaces.RATE_LIMIT, 'remaining') + + if payload.get('delay'): + time.sleep(payload['delay']) + + return { + "success": True, + "user_id": user_id, + "role": role, + "rate_limit_remaining": remaining, + "message": payload.get('message', 'Hello from middleware!') + } + +if __name__ == "__main__": + app.run() +""") + + # Start server + proc = subprocess.Popen( + ["python", str(agent_file)], + cwd=str(tmp_dir), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE + ) + + # Wait for server to start + time.sleep(3) + + yield "http://127.0.0.1:8080" + + # Cleanup + proc.terminate() + proc.wait(timeout=5) + + +def test_server_ping(middleware_server): + """Test server is running.""" + response = requests.get(f"{middleware_server}/ping") + assert response.status_code == 200 + + +def test_authenticated_request(middleware_server): + """Test authenticated request succeeds.""" + response = requests.post( + f"{middleware_server}/invocations", + json={"message": "Hello middleware!"}, + headers={"Authorization": "Bearer test-admin-token"} + ) + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["user_id"] == "admin-1" + assert data["role"] == "admin" + assert "rate_limit_remaining" in data + + +def test_unauthenticated_request(middleware_server): + """Test unauthenticated request is rejected.""" + response = requests.post( + f"{middleware_server}/invocations", + json={"message": "This should fail"} + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert "Authentication required" in data["error"] + + +def test_invalid_token(middleware_server): + """Test invalid token is rejected.""" + response = requests.post( + f"{middleware_server}/invocations", + json={"message": "Invalid token"}, + headers={"Authorization": "Bearer invalid-token"} + ) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert "Invalid token" in data["details"] + + +def test_timing_header(middleware_server): + """Test timing middleware adds header.""" + response = requests.post( + f"{middleware_server}/invocations", + json={"message": "Timing test", "delay": 0.1}, + headers={"Authorization": "Bearer test-user-token"} + ) + assert response.status_code == 200 + assert "X-Processing-Time" in response.headers + processing_time = float(response.headers["X-Processing-Time"]) + assert processing_time >= 0.1 + + +def test_rate_limiting(middleware_server): + """Test rate limiting with multiple requests.""" + # Use test-user-token which hasn't hit rate limits yet + headers = {"Authorization": "Bearer test-user-token"} + + # Track how many are remaining from first request + response = requests.post( + f"{middleware_server}/invocations", + json={"message": "Request 1"}, + headers=headers + ) + data = response.json() + assert data["success"] is True + + # Get starting count + first_remaining = data["rate_limit_remaining"] + + # Make more requests until we hit the limit + for i in range(first_remaining): + response = requests.post( + f"{middleware_server}/invocations", + json={"message": f"Request {i+2}"}, + headers=headers + ) + data = response.json() + assert data["success"] is True + + # Next request should be rate limited + response = requests.post( + f"{middleware_server}/invocations", + json={"message": "Rate limited"}, + headers=headers + ) + data = response.json() + assert "error" in data + assert "Rate limit exceeded" in data["error"] From ee5b396f200fb5a70bdfb22788e3e621a8c9eebc Mon Sep 17 00:00:00 2001 From: Sundar Raghavan Date: Fri, 31 Oct 2025 08:22:39 -0700 Subject: [PATCH 7/7] fix: lint errors --- tests_integ/runtime/test_middleware.py | 105 +++++++++++-------------- 1 file changed, 45 insertions(+), 60 deletions(-) diff --git a/tests_integ/runtime/test_middleware.py b/tests_integ/runtime/test_middleware.py index f30fc25..ebd140f 100644 --- a/tests_integ/runtime/test_middleware.py +++ b/tests_integ/runtime/test_middleware.py @@ -2,8 +2,9 @@ import subprocess import time -import requests + import pytest +import requests @pytest.fixture(scope="module") @@ -11,7 +12,7 @@ def middleware_server(tmp_path_factory): """Start a real server with middleware.""" tmp_dir = tmp_path_factory.mktemp("agent") agent_file = tmp_dir / "test_agent.py" - + # Write agent with middleware agent_file.write_text(""" import time @@ -26,22 +27,22 @@ class StandardNamespaces: class AuthMiddleware(BaseHTTPMiddleware): NAMESPACE = StandardNamespaces.AUTH - + VALID_TOKENS = { 'test-admin-token': {'user_id': 'admin-1', 'role': 'admin'}, 'test-user-token': {'user_id': 'user-1', 'role': 'user'}, } - + async def dispatch(self, request, call_next): if not hasattr(request.state, 'processing_data'): request.state.processing_data = {} - + auth_header = request.headers.get('authorization', '') - + if auth_header.startswith('Bearer '): token = auth_header[7:] user_info = self.VALID_TOKENS.get(token) - + if user_info: request.state.processing_data[self.NAMESPACE] = { 'authenticated': True, @@ -57,52 +58,52 @@ async def dispatch(self, request, call_next): 'authenticated': False, 'error': 'No authorization header' } - + return await call_next(request) class TimingMiddleware(BaseHTTPMiddleware): NAMESPACE = StandardNamespaces.TIMING - + async def dispatch(self, request, call_next): if not hasattr(request.state, 'processing_data'): request.state.processing_data = {} - + start_time = time.time() request.state.processing_data[self.NAMESPACE] = { 'start_time': start_time } - + response = await call_next(request) - + duration = time.time() - start_time response.headers['X-Processing-Time'] = f"{duration:.3f}" - + return response class RateLimitMiddleware(BaseHTTPMiddleware): NAMESPACE = StandardNamespaces.RATE_LIMIT - + def __init__(self, app, max_requests=5): super().__init__(app) self.max_requests = max_requests self.request_counts = {} - + async def dispatch(self, request, call_next): if not hasattr(request.state, 'processing_data'): request.state.processing_data = {} - + auth_data = request.state.processing_data.get(StandardNamespaces.AUTH, {}) user_id = auth_data.get('user_id', 'anonymous') - + current_time = time.time() - + self.request_counts = { uid: [t for t in times if current_time - t < 60] for uid, times in self.request_counts.items() } - + user_requests = self.request_counts.get(user_id, []) - + if len(user_requests) >= self.max_requests: request.state.processing_data[self.NAMESPACE] = { 'allowed': False, @@ -111,12 +112,12 @@ async def dispatch(self, request, call_next): else: user_requests.append(current_time) self.request_counts[user_id] = user_requests - + request.state.processing_data[self.NAMESPACE] = { 'allowed': True, 'remaining': self.max_requests - len(user_requests) } - + return await call_next(request) app = BedrockAgentCoreApp( @@ -130,24 +131,24 @@ async def dispatch(self, request, call_next): @app.entrypoint def handler(payload, context): authenticated = context.processing.get(StandardNamespaces.AUTH, 'authenticated', False) - + if not authenticated: error = context.processing.get(StandardNamespaces.AUTH, 'error', 'Unknown') return {"error": "Authentication required", "details": error} - + rate_allowed = context.processing.get(StandardNamespaces.RATE_LIMIT, 'allowed', True) - + if not rate_allowed: limit = context.processing.get(StandardNamespaces.RATE_LIMIT, 'limit') return {"error": "Rate limit exceeded", "limit": limit} - + user_id = context.processing.get(StandardNamespaces.AUTH, 'user_id') role = context.processing.get(StandardNamespaces.AUTH, 'role') remaining = context.processing.get(StandardNamespaces.RATE_LIMIT, 'remaining') - + if payload.get('delay'): time.sleep(payload['delay']) - + return { "success": True, "user_id": user_id, @@ -159,20 +160,17 @@ def handler(payload, context): if __name__ == "__main__": app.run() """) - + # Start server proc = subprocess.Popen( - ["python", str(agent_file)], - cwd=str(tmp_dir), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE + ["python", str(agent_file)], cwd=str(tmp_dir), stdout=subprocess.PIPE, stderr=subprocess.PIPE ) - + # Wait for server to start time.sleep(3) - + yield "http://127.0.0.1:8080" - + # Cleanup proc.terminate() proc.wait(timeout=5) @@ -189,7 +187,7 @@ def test_authenticated_request(middleware_server): response = requests.post( f"{middleware_server}/invocations", json={"message": "Hello middleware!"}, - headers={"Authorization": "Bearer test-admin-token"} + headers={"Authorization": "Bearer test-admin-token"}, ) assert response.status_code == 200 data = response.json() @@ -201,10 +199,7 @@ def test_authenticated_request(middleware_server): def test_unauthenticated_request(middleware_server): """Test unauthenticated request is rejected.""" - response = requests.post( - f"{middleware_server}/invocations", - json={"message": "This should fail"} - ) + response = requests.post(f"{middleware_server}/invocations", json={"message": "This should fail"}) assert response.status_code == 200 data = response.json() assert "error" in data @@ -216,7 +211,7 @@ def test_invalid_token(middleware_server): response = requests.post( f"{middleware_server}/invocations", json={"message": "Invalid token"}, - headers={"Authorization": "Bearer invalid-token"} + headers={"Authorization": "Bearer invalid-token"}, ) assert response.status_code == 200 data = response.json() @@ -229,7 +224,7 @@ def test_timing_header(middleware_server): response = requests.post( f"{middleware_server}/invocations", json={"message": "Timing test", "delay": 0.1}, - headers={"Authorization": "Bearer test-user-token"} + headers={"Authorization": "Bearer test-user-token"}, ) assert response.status_code == 200 assert "X-Processing-Time" in response.headers @@ -241,35 +236,25 @@ def test_rate_limiting(middleware_server): """Test rate limiting with multiple requests.""" # Use test-user-token which hasn't hit rate limits yet headers = {"Authorization": "Bearer test-user-token"} - + # Track how many are remaining from first request - response = requests.post( - f"{middleware_server}/invocations", - json={"message": "Request 1"}, - headers=headers - ) + response = requests.post(f"{middleware_server}/invocations", json={"message": "Request 1"}, headers=headers) data = response.json() assert data["success"] is True - + # Get starting count first_remaining = data["rate_limit_remaining"] - + # Make more requests until we hit the limit for i in range(first_remaining): response = requests.post( - f"{middleware_server}/invocations", - json={"message": f"Request {i+2}"}, - headers=headers + f"{middleware_server}/invocations", json={"message": f"Request {i + 2}"}, headers=headers ) data = response.json() assert data["success"] is True - + # Next request should be rate limited - response = requests.post( - f"{middleware_server}/invocations", - json={"message": "Rate limited"}, - headers=headers - ) + response = requests.post(f"{middleware_server}/invocations", json={"message": "Rate limited"}, headers=headers) data = response.json() assert "error" in data assert "Rate limit exceeded" in data["error"]