Skip to content

Commit 535054e

Browse files
Merge pull request #27 from mihaic/calibration
Add calibration
2 parents c3d4f44 + c444cbb commit 535054e

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

engine/base_client/client.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77

88
from benchmark import ROOT_DIR
99
from benchmark.dataset import Dataset
10+
from dataset_reader.base_reader import BaseReader
1011
from engine.base_client.configure import BaseConfigurator
12+
from engine.base_client.distances import Distance
1113
from engine.base_client.search import BaseSearcher
1214
from engine.base_client.upload import BaseUploader
1315

@@ -168,6 +170,23 @@ def run_experiment(
168170
if filter_ef_runtime and isinstance(ef, int) and (ef not in ef_runtime):
169171
print(f"\tSkipping ef runtime: {ef}; #clients {client_count} (not in ef_runtime filter)")
170172
continue
173+
174+
if (precision := search_params.get("calibration_precision", None)) is not None:
175+
top = search_params["top"]
176+
calibration_param = search_params["calibration_param"]
177+
calibration_value, calibration_precision = calibrate(
178+
searcher,
179+
calibration_param,
180+
top,
181+
precision,
182+
dataset.config.distance,
183+
reader,
184+
)
185+
print(
186+
f"Calibrated {top=} {precision=} {calibration_value=} {calibration_precision=!s}"
187+
)
188+
searcher.search_params["search_params"][calibration_param] = calibration_value
189+
171190
for repetition in range(1, REPETITIONS + 1):
172191
print(
173192
f"\tRunning repetition {repetition} ef runtime: {ef}; #clients {client_count}"
@@ -196,3 +215,52 @@ def delete_client(self):
196215

197216
for s in self.searchers:
198217
s.delete_client()
218+
219+
def calibrate(
220+
searcher: BaseSearcher,
221+
calibration_param: str,
222+
min_value: int,
223+
precision: float,
224+
distance: Distance,
225+
reader: BaseReader,
226+
max_value: int = 1000,
227+
) -> tuple[int, float]:
228+
"""Calibrate searcher for a given precision."""
229+
if min_value > max_value:
230+
raise ValueError(
231+
f"{min_value=} cannot be greater than {max_value=}"
232+
)
233+
lower_bound = min_value
234+
upper_bound = max_value
235+
lower_bound_visited = False
236+
upper_bound_visited = False
237+
current = (lower_bound + upper_bound) // 2
238+
previous = current
239+
current_precision = 0
240+
while True:
241+
searcher.search_params["search_params"][calibration_param] = current
242+
search_stats = searcher.search_all(distance, reader.read_queries())
243+
previous_precision = current_precision
244+
current_precision = search_stats["mean_precisions"]
245+
if current_precision == precision:
246+
return current, current_precision
247+
elif current_precision > precision:
248+
upper_bound = current
249+
upper_bound_visited = True
250+
else:
251+
lower_bound = current
252+
lower_bound_visited = True
253+
next_value = (lower_bound + upper_bound) // 2
254+
if (
255+
(lower_bound_visited and next_value == lower_bound)
256+
or (upper_bound_visited and next_value == upper_bound)
257+
):
258+
if abs(previous_precision - precision) < abs(current_precision - precision):
259+
final_precision = previous_precision
260+
final_value = previous
261+
else:
262+
final_precision = current_precision
263+
final_value = current
264+
return final_value, final_precision
265+
previous = current
266+
current = next_value

0 commit comments

Comments
 (0)