Add BenchmarkEvaluator with basic precision/recall computation #1870

Open · wants to merge 4 commits into base: develop
149 changes: 149 additions & 0 deletions BenchmarkEvaluator_Demo.ipynb
@@ -0,0 +1,149 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "70868bca",
"metadata": {},
"source": [
"# 🎯 BenchmarkEvaluator Demo\n",
"\n",
"This notebook demonstrates how to use `BenchmarkEvaluator` to compute precision/recall metrics for object detection tasks."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ee3b103",
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from supervision.detection.core import Detections\n",
"from supervision.metrics.benchmark import BenchmarkEvaluator"
]
},
{
"cell_type": "markdown",
"id": "f806eff5",
"metadata": {},
"source": [
"## Step 1: Create Ground Truth and Predictions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65183606",
"metadata": {},
"outputs": [],
"source": [
"# Ground truth with 2 boxes\n",
"gt = Detections(\n",
" xyxy=np.array([[10, 10, 100, 100], [150, 150, 300, 300]]), class_id=np.array([0, 1])\n",
")\n",
"\n",
"# Predictions: One perfect match, one wrong class\n",
"pred = Detections(\n",
" xyxy=np.array([[10, 10, 100, 100], [150, 150, 300, 300]]), class_id=np.array([0, 2])\n",
")"
]
},
{
"cell_type": "markdown",
"id": "529f0ef0",
"metadata": {},
"source": [
"## Step 2: Run BenchmarkEvaluator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5615d704",
"metadata": {},
"outputs": [],
"source": [
"evaluator = BenchmarkEvaluator(ground_truth=gt, predictions=pred)\n",
"metrics = evaluator.compute_precision_recall()\n",
"print(\"Precision:\", metrics[\"precision\"])\n",
"print(\"Recall:\", metrics[\"recall\"])"
]
},
{
"cell_type": "markdown",
"id": "9ab6f923",
"metadata": {},
"source": [
"## Step 3: Per-Class Metrics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dde2bc49",
"metadata": {},
"outputs": [],
"source": [
"per_class = evaluator.compute_precision_recall_per_class()\n",
"for cls, metric in per_class.items():\n",
" print(\n",
" f\"Class {cls} - Precision: {metric['precision']:.2f}, Recall: {metric['recall']:.2f}\"\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "dfa1f1e5",
"metadata": {},
"source": [
"## Step 4: Visualize Bounding Boxes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6a6ce9d",
"metadata": {},
"outputs": [],
"source": [
"def draw_boxes(image, detections, color, label):\n",
" for box, cls in zip(detections.xyxy, detections.class_id):\n",
" x1, y1, x2, y2 = box.astype(int)\n",
" cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)\n",
" cv2.putText(\n",
" image,\n",
" f\"{label}:{cls}\",\n",
" (x1, y1 - 10),\n",
" cv2.FONT_HERSHEY_SIMPLEX,\n",
" 0.5,\n",
" color,\n",
" 2,\n",
" )\n",
"\n",
"\n",
"canvas = np.ones((350, 350, 3), dtype=np.uint8) * 255\n",
"draw_boxes(canvas, gt, (0, 255, 0), \"GT\")\n",
"draw_boxes(canvas, pred, (0, 0, 255), \"Pred\")\n",
"\n",
"plt.imshow(canvas[..., ::-1])\n",
"plt.title(\"Ground Truth (Green) vs Prediction (Red)\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "7b3d6112",
"metadata": {},
"source": [
"🎉 That's it! You've run a complete object detection benchmark with precision/recall metrics and visualization."
]
}
],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 5
}
40 changes: 40 additions & 0 deletions supervision/metrics/benchmark.py
@@ -0,0 +1,40 @@
# supervision/metrics/benchmark.py

from typing import Dict, Optional

from supervision.detection.core import Detections


class BenchmarkEvaluator:
def __init__(
self,
ground_truth: Detections,
predictions: Detections,
class_map: Optional[Dict[str, str]] = None,
iou_threshold: float = 0.5,
):
self.ground_truth = ground_truth
self.predictions = predictions
self.class_map = class_map or {}
self.iou_threshold = iou_threshold

def compute_precision_recall(self) -> Dict[str, float]:
"""
Compute basic precision and recall metrics.
For demo purposes — you will expand this.
"""
# TODO: Add class alignment, matching using IoU
tp = len(self.predictions.xyxy) # Placeholder
fp = 0
fn = len(self.ground_truth.xyxy) - tp
Comment on lines +26 to +29

Contributor: The logic here is incomplete, please add the correct logic to compute precision and recall.
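A possible starting point for the matching requested above: a hypothetical helper, sketched as an untested greedy one-to-one matcher that requires class agreement and an IoU at or above `iou_threshold`, written with plain Python loops rather than any particular supervision utility. Note that the demo notebook also calls `compute_precision_recall_per_class()`, which is not defined in this file; it could reuse the same helper with both sets filtered by class.

```python
def _match_detections(gt_xyxy, gt_class_id, pred_xyxy, pred_class_id, iou_threshold=0.5):
    """Greedy one-to-one matching of predictions to ground truth; returns (tp, fp, fn)."""
    matched_gt = set()
    tp = 0
    for p_box, p_cls in zip(pred_xyxy, pred_class_id):
        best_iou, best_idx = 0.0, None
        for i, (g_box, g_cls) in enumerate(zip(gt_xyxy, gt_class_id)):
            if i in matched_gt or g_cls != p_cls:
                continue
            # IoU of two xyxy boxes
            ix1, iy1 = max(p_box[0], g_box[0]), max(p_box[1], g_box[1])
            ix2, iy2 = min(p_box[2], g_box[2]), min(p_box[3], g_box[3])
            inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
            union = (
                (p_box[2] - p_box[0]) * (p_box[3] - p_box[1])
                + (g_box[2] - g_box[0]) * (g_box[3] - g_box[1])
                - inter
            )
            iou = inter / union if union > 0 else 0.0
            if iou > best_iou:
                best_iou, best_idx = iou, i
        if best_idx is not None and best_iou >= iou_threshold:
            matched_gt.add(best_idx)
            tp += 1
    fp = len(pred_xyxy) - tp
    fn = len(gt_xyxy) - tp
    return tp, fp, fn
```

`compute_precision_recall()` could then derive precision = tp / (tp + fp) and recall = tp / (tp + fn) from these counts instead of the placeholder values.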


precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0

return {"precision": precision, "recall": recall}

def summary(self) -> None:
metrics = self.compute_precision_recall()
print("Benchmark Summary:")
for k, v in metrics.items():
print(f"{k}: {v:.4f}")
15 changes: 15 additions & 0 deletions tests/metrics/test_benchmark.py
@@ -0,0 +1,15 @@
import numpy as np

from supervision.detection.core import Detections
from supervision.metrics.benchmark import BenchmarkEvaluator


def test_basic_precision_recall():
Contributor: This too seems like a placeholder test; please proceed with the implementation and add comprehensive unit tests.

gt = Detections(xyxy=np.array([[0, 0, 100, 100]]), class_id=np.array([0]))
pred = Detections(xyxy=np.array([[0, 0, 100, 100]]), class_id=np.array([0]))

evaluator = BenchmarkEvaluator(ground_truth=gt, predictions=pred)
metrics = evaluator.compute_precision_recall()

assert metrics["precision"] == 1.0
assert metrics["recall"] == 1.0
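Following the reviewer's request, a couple of additional cases the eventual IoU-based, class-aware implementation should satisfy (hypothetical tests reusing the imports above; they would fail against the current placeholder logic):

```python
def test_precision_recall_no_overlap():
    # A prediction that overlaps no ground truth box is a false positive,
    # and the unmatched ground truth box is a false negative.
    gt = Detections(xyxy=np.array([[0, 0, 100, 100]]), class_id=np.array([0]))
    pred = Detections(xyxy=np.array([[200, 200, 300, 300]]), class_id=np.array([0]))

    metrics = BenchmarkEvaluator(ground_truth=gt, predictions=pred).compute_precision_recall()

    assert metrics["precision"] == 0.0
    assert metrics["recall"] == 0.0


def test_precision_recall_wrong_class():
    # A perfectly localized box with the wrong class id should not count as a match.
    gt = Detections(xyxy=np.array([[0, 0, 100, 100]]), class_id=np.array([0]))
    pred = Detections(xyxy=np.array([[0, 0, 100, 100]]), class_id=np.array([1]))

    metrics = BenchmarkEvaluator(ground_truth=gt, predictions=pred).compute_precision_recall()

    assert metrics["precision"] == 0.0
    assert metrics["recall"] == 0.0
```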