Skip to content

Commit d74fc46

Browse files
f-wright and Auto-format Bot authored
Bounding box mode detector creation (#355)
This PR adds a function to the experimental API to create bounding box mode detectors. This mode is not currently enabled for general usage, so the tests are disabled for now. --------- Co-authored-by: Auto-format Bot <[email protected]>
1 parent 0d8f39e commit d74fc46

File tree

3 files changed

+133
-4
lines changed

3 files changed

+133
-4
lines changed

src/groundlight/experimental_api.py

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from groundlight_openapi_client.api.image_queries_api import ImageQueriesApi
2121
from groundlight_openapi_client.api.notes_api import NotesApi
2222
from groundlight_openapi_client.model.action_request import ActionRequest
23+
from groundlight_openapi_client.model.bounding_box_mode_configuration import BoundingBoxModeConfiguration
2324
from groundlight_openapi_client.model.channel_enum import ChannelEnum
2425
from groundlight_openapi_client.model.condition_request import ConditionRequest
2526
from groundlight_openapi_client.model.count_mode_configuration import CountModeConfiguration
@@ -902,10 +903,12 @@ def create_counting_detector( # noqa: PLR0913 # pylint: disable=too-many-argume
902903
metadata=metadata,
903904
)
904905
detector_creation_input.mode = ModeEnum.COUNT
905-
# TODO: pull the BE defined default
906+
906907
if max_count is None:
907-
max_count = 10
908-
mode_config = CountModeConfiguration(max_count=max_count, class_name=class_name)
908+
mode_config = CountModeConfiguration(class_name=class_name)
909+
else:
910+
mode_config = CountModeConfiguration(max_count=max_count, class_name=class_name)
911+
909912
detector_creation_input.mode_configuration = mode_config
910913
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
911914
return Detector.parse_obj(obj.to_dict())
@@ -974,6 +977,81 @@ def create_multiclass_detector( # noqa: PLR0913 # pylint: disable=too-many-argu
974977
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
975978
return Detector.parse_obj(obj.to_dict())
976979

980+
def create_bounding_box_detector(  # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals
    self,
    name: str,
    query: str,
    class_name: str,
    *,
    max_num_bboxes: Optional[int] = None,
    group_name: Optional[str] = None,
    confidence_threshold: Optional[float] = None,
    patience_time: Optional[float] = None,
    pipeline_config: Optional[str] = None,
    metadata: Union[dict, str, None] = None,
) -> Detector:
    """
    Creates a bounding box detector that can detect objects in images up to a specified maximum number of bounding
    boxes.

    **Example usage**::

        gl = ExperimentalApi()

        # Create a detector that draws boxes around up to 5 people
        detector = gl.create_bounding_box_detector(
            name="people_locator",
            query="Draw a bounding box around each person in the image",
            class_name="person",
            max_num_bboxes=5,
            confidence_threshold=0.9,
            patience_time=30.0
        )

        # Use the detector to find people in an image
        image_query = gl.ask_ml(detector, "path/to/image.jpg")
        print(f"Confidence: {image_query.result.confidence}")
        print(f"Bounding boxes: {image_query.rois}")

    :param name: A short, descriptive name for the detector.
    :param query: A question about the object to detect in the image.
    :param class_name: The class name of the object to detect.
    :param max_num_bboxes: Maximum number of bounding boxes to detect. If omitted, no value is sent
        with the request and the limit is left to the backend's default.
    :param group_name: Optional name of a group to organize related detectors together.
    :param confidence_threshold: A value that sets the minimum confidence level required for the ML model's
        predictions. If confidence is below this threshold, the query may be sent for human review.
    :param patience_time: The maximum time in seconds that Groundlight will attempt to generate a
        confident prediction before falling back to human review. Defaults to 30 seconds.
    :param pipeline_config: Advanced usage only. Configuration string needed to instantiate a specific
        prediction pipeline for this detector.
    :param metadata: A dictionary or JSON string containing custom key/value pairs to associate with
        the detector (limited to 1KB). This metadata can be used to store additional
        information like location, purpose, or related system IDs. You can retrieve this
        metadata later by calling `get_detector()`.

    :return: The created Detector object
    """

    detector_creation_input = self._prep_create_detector(
        name=name,
        query=query,
        group_name=group_name,
        confidence_threshold=confidence_threshold,
        patience_time=patience_time,
        pipeline_config=pipeline_config,
        metadata=metadata,
    )
    detector_creation_input.mode = ModeEnum.BOUNDING_BOX

    # Only include max_num_bboxes when the caller chose a value; leaving it out of the
    # configuration (rather than hard-coding a number here) defers the default to the backend.
    mode_config_kwargs = {"class_name": class_name}
    if max_num_bboxes is not None:
        mode_config_kwargs["max_num_bboxes"] = max_num_bboxes
    detector_creation_input.mode_configuration = BoundingBoxModeConfiguration(**mode_config_kwargs)

    obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
    return Detector.parse_obj(obj.to_dict())
1054+
9771055
def _download_mlbinary_url(self, detector: Union[str, Detector]) -> EdgeModelInfo:
9781056
"""
9791057
Gets a temporary presigned URL to download the model binaries for the given detector, along

test/integration/test_groundlight.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from ksuid import KsuidMs
1818
from model import (
1919
BinaryClassificationResult,
20+
BoundingBoxResult,
2021
CountingResult,
2122
Detector,
2223
ImageQuery,
@@ -35,6 +36,7 @@ def is_valid_display_result(result: Any) -> bool:
3536
not isinstance(result, BinaryClassificationResult)
3637
and not isinstance(result, CountingResult)
3738
and not isinstance(result, MultiClassificationResult)
39+
and not isinstance(result, BoundingBoxResult)
3840
):
3941
return False
4042

test/unit/test_experimental.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import time
2-
from datetime import datetime
2+
from datetime import datetime, timezone
33

44
import pytest
55
from groundlight import ExperimentalApi
@@ -145,3 +145,52 @@ def test_multiclass_detector(gl_experimental: ExperimentalApi):
145145
mc_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
146146
assert mc_iq.result.label is not None
147147
assert mc_iq.result.label in class_names
148+
149+
150+
@pytest.mark.skip(
    reason=(
        "General users currently can't use bounding box detectors. If you have questions, reach out"
        " to Groundlight support, or upgrade your plan."
    )
)
def test_bounding_box_detector(gl_experimental: ExperimentalApi):
    """
    Verify that we can create and submit to a bounding box detector.

    Skipped for now: bounding box mode is not enabled for general accounts.
    """
    # A UTC timestamp in the name keeps detector names distinct between runs.
    name = f"Test {datetime.now(timezone.utc)}"
    created_detector = gl_experimental.create_bounding_box_detector(
        name, "Draw a bounding box around each dog in the image", "dog"
    )
    assert created_detector is not None
    bbox_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
    assert bbox_iq.result.label is not None
    # The regions of interest are read from the image query itself, not from its result.
    assert bbox_iq.rois is not None
168+
169+
170+
@pytest.mark.skip(
    reason=(
        "General users currently can't use bounding box detectors. If you have questions, reach out"
        " to Groundlight support, or upgrade your plan."
    )
)
def test_bounding_box_detector_async(gl_experimental: ExperimentalApi):
    """
    Verify that we can create and submit to a bounding box detector with ask_async.

    Skipped for now: bounding box mode is not enabled for general accounts.
    """
    # A UTC timestamp in the name keeps detector names distinct between runs.
    name = f"Test {datetime.now(timezone.utc)}"
    created_detector = gl_experimental.create_bounding_box_detector(
        name, "Draw a bounding box around each dog in the image", "dog"
    )
    assert created_detector is not None
    async_iq = gl_experimental.ask_async(created_detector, "test/assets/dog.jpeg")

    # attempting to access fields within the result should raise an exception
    with pytest.raises(AttributeError):
        _ = async_iq.result.label  # type: ignore
    with pytest.raises(AttributeError):
        _ = async_iq.result.confidence  # type: ignore

    # Give the server a moment to process the query before fetching it again.
    time.sleep(5)
    # you should be able to get a "real" result by retrieving an updated image query object from the server
    _image_query = gl_experimental.get_image_query(id=async_iq.id)
    assert _image_query.result is not None

0 commit comments

Comments
 (0)