diff --git a/aidefense/runtime/inspection_client.py b/aidefense/runtime/inspection_client.py index 844211d..846a1ae 100644 --- a/aidefense/runtime/inspection_client.py +++ b/aidefense/runtime/inspection_client.py @@ -118,8 +118,42 @@ def _inspect(self, *args, **kwargs): """ raise NotImplementedError("Subclasses must implement _inspect.") + @staticmethod + def _parse_rule(rule_data: Dict[str, Any]) -> Rule: + """ + Parse a single rule from API response data into a Rule object. + + Args: + rule_data (Dict[str, Any]): The rule data from the API response. + + Returns: + Rule: The parsed Rule object. + """ + # Try to convert to enum, keep original string if not in enum + rule_name = rule_data["rule_name"] + try: + rule_name = RuleName(rule_data["rule_name"]) + except ValueError: + # Keep the original string for custom rule names + pass + + # Try to convert to enum, keep original string if not in enum + classification = rule_data.get("classification") + try: + classification = Classification(rule_data["classification"]) + except ValueError: + # Keep the original string for custom classifications + pass + + return Rule( + rule_name=rule_name, + entity_types=rule_data.get("entity_types"), + rule_id=rule_data.get("rule_id"), + classification=classification, + ) + def _parse_inspect_response( - self, response_data: Dict[str, Any] + self, response_data: Dict[str, Any] ) -> "InspectResponse": """ Parse API response (chat or http inspect) into an InspectResponse object. @@ -149,6 +183,16 @@ def _parse_inspect_response( "classification": "SECURITY_VIOLATION" } ], + "processed_rules": [ + { + "rule_name": "Prompt Injection", + "rule_id": 0, + "entity_types": [ + "" + ], + "classification": "SECURITY_VIOLATION" + } + ], "attack_technique": "NONE_ATTACK_TECHNIQUE", "explanation": "", "client_transaction_id": "", @@ -170,6 +214,14 @@ def _parse_inspect_response( classification=Classification.SECURITY_VIOLATION ) ], + processed_rules=[ + Rule( + rule_name="Prompt Injection", # Note: This will remain a string since it's not in RuleName enum + rule_id=0, + entity_types=[""], + classification=Classification.SECURITY_VIOLATION + ) + ], attack_technique="NONE_ATTACK_TECHNIQUE", explanation="", client_transaction_id="", @@ -191,30 +243,11 @@ def _parse_inspect_response( # Log invalid classification but don't add it self.config.logger.warning(f"Invalid classification type: {cls}") # Parse rules if present - rules = [] - for rule_data in response_data.get("rules", []): - # Try to convert to enum, keep original string if not in enum - rule_name = rule_data["rule_name"] - try: - rule_name = RuleName(rule_data["rule_name"]) - except ValueError: - # Keep the original string for custom rule names - pass - # Try to convert to enum, keep original string if not in enum - classification = rule_data.get("classification") - try: - classification = Classification(rule_data["classification"]) - except ValueError: - # Keep the original string for custom classifications - pass - rules.append( - Rule( - rule_name=rule_name, - entity_types=rule_data.get("entity_types"), - rule_id=rule_data.get("rule_id"), - classification=classification, - ) - ) + rules = [self._parse_rule(rule_data) for rule_data in response_data.get("rules", [])] + + # Parse processed rules if present + processed_rules = [self._parse_rule(rule_data) for rule_data in response_data.get("processed_rules", [])] + # Parse severity if present severity = None try: @@ -227,6 +260,7 @@ def _parse_inspect_response( is_safe=response_data.get("is_safe", True), severity=severity, rules=rules or None, + processed_rules=processed_rules or None, attack_technique=response_data.get("attack_technique"), explanation=response_data.get("explanation"), client_transaction_id=response_data.get("client_transaction_id"), diff --git a/aidefense/runtime/models.py b/aidefense/runtime/models.py index 344a4c8..bb02901 100644 --- a/aidefense/runtime/models.py +++ b/aidefense/runtime/models.py @@ -162,6 +162,7 @@ class InspectResponse: explanation (Optional[str]): Human-readable explanation of the inspection result. client_transaction_id (Optional[str]): Unique client-provided transaction ID for tracing. event_id (Optional[str]): Unique event ID assigned by the backend. + processed_rules (Optional[List[Rule]]): List of rules applied for inspection. """ classifications: List[Classification] @@ -172,3 +173,4 @@ class InspectResponse: explanation: Optional[str] = None client_transaction_id: Optional[str] = None event_id: Optional[str] = None + processed_rules: Optional[List[Rule]] = None