1010from openapi_client .api .image_queries_api import ImageQueriesApi
1111from openapi_client .model .detector_creation_input import DetectorCreationInput
1212
13- from groundlight .binary_labels import convert_display_label_to_internal
13+ from groundlight .binary_labels import Label , convert_display_label_to_internal , convert_internal_label_to_display
1414from groundlight .config import API_TOKEN_VARIABLE_NAME , API_TOKEN_WEB_URL
1515from groundlight .images import parse_supported_image_types
1616from groundlight .internalapi import GroundlightApiClient , NotFoundError , sanitize_endpoint_url
@@ -71,6 +71,15 @@ def __init__(self, endpoint: Optional[str] = None, api_token: Optional[str] = No
7171 self .detectors_api = DetectorsApi (self .api_client )
7272 self .image_queries_api = ImageQueriesApi (self .api_client )
7373
74+ @classmethod
75+ def _post_process_image_query (cls , iq : ImageQuery ) -> ImageQuery :
76+ """Post-process the image query so we don't use confusing internal labels.
77+
78+ TODO: Get rid of this once we clean up the mapping logic server-side.
79+ """
80+ iq .result .label = convert_internal_label_to_display (iq , iq .result .label )
81+ return iq
82+
7483 def get_detector (self , id : Union [str , Detector ]) -> Detector : # pylint: disable=redefined-builtin
7584 if isinstance (id , Detector ):
7685 # Short-circuit
@@ -102,7 +111,12 @@ def create_detector(
102111 return Detector .parse_obj (obj .to_dict ())
103112
104113 def get_or_create_detector (
105- self , name : str , query : str , * , confidence_threshold : Optional [float ] = None , config_name : Optional [str ] = None
114+ self ,
115+ name : str ,
116+ query : str ,
117+ * ,
118+ confidence_threshold : Optional [float ] = None ,
119+ config_name : Optional [str ] = None ,
106120 ) -> Detector :
107121 """Tries to look up the detector by name. If a detector with that name, query, and
108122 confidence exists, return it. Otherwise, create a detector with the specified query and
@@ -113,30 +127,41 @@ def get_or_create_detector(
113127 except NotFoundError :
114128 logger .debug (f"We could not find a detector with name='{ name } '. So we will create a new detector ..." )
115129 return self .create_detector (
116- name = name , query = query , confidence_threshold = confidence_threshold , config_name = config_name
130+ name = name ,
131+ query = query ,
132+ confidence_threshold = confidence_threshold ,
133+ config_name = config_name ,
117134 )
118135
119136 # TODO: We may soon allow users to update the retrieved detector's fields.
120137 if existing_detector .query != query :
121138 raise ValueError (
122- f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the queries don't match."
123- f" The existing query is '{ existing_detector .query } '."
139+ (
140+ f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the queries don't match."
141+ f" The existing query is '{ existing_detector .query } '."
142+ ),
124143 )
125144 if confidence_threshold is not None and existing_detector .confidence_threshold != confidence_threshold :
126145 raise ValueError (
127- f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the confidence"
128- " thresholds don't match. The existing confidence threshold is"
129- f" { existing_detector .confidence_threshold } ."
146+ (
147+ f"Found existing detector with name={ name } (id={ existing_detector .id } ) but the confidence"
148+ " thresholds don't match. The existing confidence threshold is"
149+ f" { existing_detector .confidence_threshold } ."
150+ ),
130151 )
131152 return existing_detector
132153
133154 def get_image_query (self , id : str ) -> ImageQuery : # pylint: disable=redefined-builtin
134155 obj = self .image_queries_api .get_image_query (id = id )
135- return ImageQuery .parse_obj (obj .to_dict ())
156+ iq = ImageQuery .parse_obj (obj .to_dict ())
157+ return self ._post_process_image_query (iq )
136158
137159 def list_image_queries (self , page : int = 1 , page_size : int = 10 ) -> PaginatedImageQueryList :
138160 obj = self .image_queries_api .list_image_queries (page = page , page_size = page_size )
139- return PaginatedImageQueryList .parse_obj (obj .to_dict ())
161+ image_queries = PaginatedImageQueryList .parse_obj (obj .to_dict ())
162+ if image_queries .results is not None :
163+ image_queries .results = [self ._post_process_image_query (iq ) for iq in image_queries .results ]
164+ return image_queries
140165
141166 def submit_image_query (
142167 self ,
@@ -166,7 +191,7 @@ def submit_image_query(
166191 if wait :
167192 threshold = self .get_detector (detector ).confidence_threshold
168193 image_query = self .wait_for_confident_result (image_query , confidence_threshold = threshold , timeout_sec = wait )
169- return image_query
194+ return self . _post_process_image_query ( image_query )
170195
171196 def wait_for_confident_result (
172197 self ,
@@ -203,11 +228,11 @@ def wait_for_confident_result(
203228 image_query = self .get_image_query (image_query .id )
204229 return image_query
205230
206- def add_label (self , image_query : Union [ImageQuery , str ], label : str ):
231+ def add_label (self , image_query : Union [ImageQuery , str ], label : Union [ Label , str ] ):
207232 """A new label to an image query. This answers the detector's question.
208233 :param image_query: Either an ImageQuery object (returned from `submit_image_query`) or
209234 an image_query id as a string.
210- :param label: The string "Yes " or the string "No " in answer to the query.
235+ :param label: The string "YES " or the string "NO " in answer to the query.
211236 """
212237 if isinstance (image_query , ImageQuery ):
213238 image_query_id = image_query .id
0 commit comments