diff --git a/.github/workflows/cibuildwheel.yml b/.github/workflows/cibuildwheel.yml index 378d0c8c..87198b98 100644 --- a/.github/workflows/cibuildwheel.yml +++ b/.github/workflows/cibuildwheel.yml @@ -60,3 +60,9 @@ jobs: working-directory: ${{ runner.temp }} run: python -m unittest discover -s ${GITHUB_WORKSPACE}/bindings/python + - name: Run examples + env: + PYTHONPATH: ${{ runner.temp }}/usr + CTEST_OUTPUT_ON_FAILURE: 1 + working-directory: ${{ runner.temp }} + run: python -m unittest discover -p "example*.py" -s ${GITHUB_WORKSPACE}/examples/python diff --git a/examples/python/example_vamana.py b/examples/python/example_vamana.py index b988a07c..b8346bf9 100644 --- a/examples/python/example_vamana.py +++ b/examples/python/example_vamana.py @@ -21,11 +21,12 @@ # [imports] DEBUG_MODE = False -def assert_equal(lhs, rhs, message: str = ""): +def assert_equal(lhs, rhs, message: str = "", epsilon = 0.05): if DEBUG_MODE: print(f"{message}: {lhs} == {rhs}") else: - assert lhs == rhs, message + assert lhs < rhs + epsilon, message + assert lhs > rhs - epsilon, message def run_test_float(index, queries, groundtruth): expected = { @@ -79,7 +80,6 @@ def run_test_build_two_level4_8(index, queries, groundtruth): test_data_dir = None def run(): - # ### # Generating test data # ### @@ -159,7 +159,7 @@ def run(): # Compare with the groundtruth. recall = svs.k_recall_at(groundtruth, I, 10, 10) print(f"Recall = {recall}") - assert(recall == 0.8288) + assert_equal(recall, 0.8288) # [perform-queries] # [search-window-size] @@ -213,7 +213,7 @@ def run(): # Compare with the groundtruth. recall = svs.k_recall_at(groundtruth, I, 10, 10) print(f"Recall = {recall}") - assert(recall == 0.8288) + assert_equal(recall, 0.8288) # [loading] ##### Begin Test diff --git a/examples/python/example_vamana_dynamic.py b/examples/python/example_vamana_dynamic.py index 3d57bd12..45f087b2 100644 --- a/examples/python/example_vamana_dynamic.py +++ b/examples/python/example_vamana_dynamic.py @@ -22,11 +22,12 @@ # [imports] DEBUG_MODE = False -def assert_equal(lhs, rhs, message: str = ""): +def assert_equal(lhs, rhs, message: str = "", epsilon = 0.05): if DEBUG_MODE: print(f"{message}: {lhs} == {rhs}") else: - assert lhs == rhs, message + assert lhs < rhs + epsilon, message + assert lhs > rhs - epsilon, message def run_test_float(index, queries, groundtruth): expected = { @@ -118,7 +119,7 @@ def run(): # Compare with the groundtruth. recall = svs.k_recall_at(groundtruth, I, 10, 10) print(f"Recall = {recall}") - assert(recall == 0.8202) + assert_equal(recall, 0.8202) # [perform-queries] ##### Begin Test @@ -158,8 +159,7 @@ def run(): # Compare with the groundtruth. recall = svs.k_recall_at(groundtruth, I, 10, 10) print(f"Recall = {recall}") - assert(recall == 0.8202) - + assert_equal(recall, 0.8202) ##### Begin Test run_test_float(index, queries, groundtruth)