Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 59 additions & 25 deletions aidefense/runtime/inspection_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,42 @@ def _inspect(self, *args, **kwargs):
"""
raise NotImplementedError("Subclasses must implement _inspect.")

@staticmethod
def _parse_rule(rule_data: Dict[str, Any]) -> Rule:
    """
    Parse a single rule from API response data into a Rule object.

    "rule_name" and "classification" are converted to their enum types
    (RuleName / Classification) when the value is a valid enum value;
    otherwise the raw value from the response is kept as-is.

    Args:
        rule_data (Dict[str, Any]): The rule data from the API response.
            "rule_name" is required; "entity_types", "rule_id" and
            "classification" are optional and default to None when absent.

    Returns:
        Rule: The parsed Rule object.

    Raises:
        KeyError: If "rule_name" is missing from rule_data.
    """
    # Try to convert to enum, keep original string if not in enum
    rule_name = rule_data["rule_name"]
    try:
        rule_name = RuleName(rule_name)
    except ValueError:
        # Keep the original string for custom rule names
        pass

    # Classification is optional: only attempt the enum conversion when a
    # value is present. Indexing rule_data["classification"] directly would
    # raise an unhandled KeyError for rules that omit the field.
    classification = rule_data.get("classification")
    if classification is not None:
        try:
            classification = Classification(classification)
        except ValueError:
            # Keep the original string for custom classifications
            pass

    return Rule(
        rule_name=rule_name,
        entity_types=rule_data.get("entity_types"),
        rule_id=rule_data.get("rule_id"),
        classification=classification,
    )

def _parse_inspect_response(
self, response_data: Dict[str, Any]
self, response_data: Dict[str, Any]
) -> "InspectResponse":
"""
Parse API response (chat or http inspect) into an InspectResponse object.
Expand Down Expand Up @@ -149,6 +183,16 @@ def _parse_inspect_response(
"classification": "SECURITY_VIOLATION"
}
],
"processed_rules": [
{
"rule_name": "Prompt Injection",
"rule_id": 0,
"entity_types": [
""
],
"classification": "SECURITY_VIOLATION"
}
],
"attack_technique": "NONE_ATTACK_TECHNIQUE",
"explanation": "",
"client_transaction_id": "",
Expand All @@ -170,6 +214,14 @@ def _parse_inspect_response(
classification=Classification.SECURITY_VIOLATION
)
],
processed_rules=[
Rule(
rule_name="Prompt Injection", # Note: This will remain a string since it's not in RuleName enum
rule_id=0,
entity_types=[""],
classification=Classification.SECURITY_VIOLATION
)
],
attack_technique="NONE_ATTACK_TECHNIQUE",
explanation="",
client_transaction_id="",
Expand All @@ -191,30 +243,11 @@ def _parse_inspect_response(
# Log invalid classification but don't add it
self.config.logger.warning(f"Invalid classification type: {cls}")
# Parse rules if present
rules = []
for rule_data in response_data.get("rules", []):
# Try to convert to enum, keep original string if not in enum
rule_name = rule_data["rule_name"]
try:
rule_name = RuleName(rule_data["rule_name"])
except ValueError:
# Keep the original string for custom rule names
pass
# Try to convert to enum, keep original string if not in enum
classification = rule_data.get("classification")
try:
classification = Classification(rule_data["classification"])
except ValueError:
# Keep the original string for custom classifications
pass
rules.append(
Rule(
rule_name=rule_name,
entity_types=rule_data.get("entity_types"),
rule_id=rule_data.get("rule_id"),
classification=classification,
)
)
rules = [self._parse_rule(rule_data) for rule_data in response_data.get("rules", [])]

# Parse processed rules if present
processed_rules = [self._parse_rule(rule_data) for rule_data in response_data.get("processed_rules", [])]

# Parse severity if present
severity = None
try:
Expand All @@ -227,6 +260,7 @@ def _parse_inspect_response(
is_safe=response_data.get("is_safe", True),
severity=severity,
rules=rules or None,
processed_rules=processed_rules or None,
attack_technique=response_data.get("attack_technique"),
explanation=response_data.get("explanation"),
client_transaction_id=response_data.get("client_transaction_id"),
Expand Down
2 changes: 2 additions & 0 deletions aidefense/runtime/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ class InspectResponse:
explanation (Optional[str]): Human-readable explanation of the inspection result.
client_transaction_id (Optional[str]): Unique client-provided transaction ID for tracing.
event_id (Optional[str]): Unique event ID assigned by the backend.
processed_rules (Optional[List[Rule]]): List of rules applied for inspection.
"""

classifications: List[Classification]
Expand All @@ -172,3 +173,4 @@ class InspectResponse:
explanation: Optional[str] = None
client_transaction_id: Optional[str] = None
event_id: Optional[str] = None
processed_rules: Optional[List[Rule]] = None