Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 68 additions & 12 deletions src/svsbench/generate_ground_truth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
parser.add_argument("--vecs_file", help="Vectors *vecs file", type=Path)
parser.add_argument("--query_file", help="Query vectors file", type=Path)
parser.add_argument("--out_file", help="Output file", type=Path)
parser.add_argument(
"--query_out_file",
help="Output file for query vectors generated when num_queries given",
type=Path,
)
parser.add_argument(
"--distance",
help="Distance",
Expand All @@ -33,6 +38,14 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
"-k", help="Number of neighbors", type=int, default=100
)
parser.add_argument("--num_vectors", help="Number of vectors", type=int)
parser.add_argument(
"--num_query_vectors",
help="Number of query vectors."
" If given, query vectors will be shuffled."
" If more than in the query file, the query vectors will be shuffled"
" and repeated as needed.",
type=int,
)
parser.add_argument(
"--shuffle", help="Shuffle order of vectors", action="store_true"
)
Expand All @@ -54,8 +67,10 @@ def main(argv: str | None = None) -> None:
k=args.k,
num_threads=args.max_threads,
out_file=args.out_file,
query_out_path=args.query_out_file,
shuffle=args.shuffle,
seed=args.seed,
num_query_vectors=args.num_query_vectors,
)


Expand All @@ -64,31 +79,72 @@ def generate_ground_truth(
vecs_path: Path,
query_file: Path,
distance: svs.DistanceType,
num_vectors: int | None,
num_vectors: int | None = None,
k: int = 100,
num_threads: int = 1,
out_file: Path | None = None,
query_out_path: Path | None = None,
shuffle: bool = False,
seed: int = 42,
num_query_vectors: int | None = None,
) -> None:
if out_file is None:
out_file = utils.ground_truth_path(
vecs_path, query_file, distance, num_vectors, seed if shuffle else None,
if out_file is not None and out_file.suffix != ".ivecs":
raise SystemExit("Error: --out_file must end in .ivecs")
if (
query_out_path is not None
and query_out_path.suffix != query_file.suffix
):
raise SystemExit(
"Error: --query_out_path must have the same suffix as --query_file"
)
else:
if out_file.suffix != ".ivecs":
raise SystemExit("Error: --out_file must end in .ivecs")
out_file = str(out_file)
queries = svs.read_vecs(str(query_file))
vectors = svs.read_vecs(str(vecs_path))
if num_vectors is None:
num_vectors = vectors.shape[0]
# If num_vectors is None or larger than the number of vectors,
# slicing will return the whole array.
vectors = vectors[:num_vectors]
if shuffle:
vectors = vectors[np.random.default_rng(seed).permutation(num_vectors)]
np.random.default_rng(seed).shuffle(vectors)
index = svs.Flat(vectors, distance=distance, num_threads=num_threads)
idxs, _ = index.search(queries, k)
svs.write_vecs(idxs.astype(np.uint32), out_file)
if num_query_vectors is not None:
queries_all = np.empty_like(
queries, shape=(num_query_vectors, queries.shape[1])
)
ground_truth_all = np.empty_like(
idxs, shape=(num_query_vectors, idxs.shape[1])
)
rng = np.random.default_rng(seed)
cursor = 0
while cursor < num_query_vectors:
permutation = rng.permutation(len(queries))
batch_size = min(num_query_vectors - cursor, len(queries))
queries_all[cursor : cursor + batch_size] = queries[
permutation[:batch_size]
]
ground_truth_all[cursor : cursor + batch_size] = idxs[
permutation[:batch_size]
]
Comment on lines +117 to +126
Copy link

Copilot AI Jun 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding inline comments within the while-loop block that repeats queries to clarify its logic and purpose, which can help future readers understand the batch-based repetition process.

Suggested change
cursor = 0
while cursor < num_query_vectors:
permutation = rng.permutation(len(queries))
batch_size = min(num_query_vectors - cursor, len(queries))
queries_all[cursor : cursor + batch_size] = queries[
permutation[:batch_size]
]
ground_truth_all[cursor : cursor + batch_size] = idxs[
permutation[:batch_size]
]
cursor = 0
# Repeat the process until we have generated the required number of query vectors.
while cursor < num_query_vectors:
# Generate a random permutation of the query indices to shuffle the queries.
permutation = rng.permutation(len(queries))
# Determine the size of the current batch, ensuring we don't exceed the total required.
batch_size = min(num_query_vectors - cursor, len(queries))
# Select a batch of queries based on the permutation and add them to the output array.
queries_all[cursor : cursor + batch_size] = queries[
permutation[:batch_size]
]
# Select the corresponding ground truth indices for the batch and add them to the output array.
ground_truth_all[cursor : cursor + batch_size] = idxs[
permutation[:batch_size]
]
# Update the cursor to reflect the number of queries processed so far.

Copilot uses AI. Check for mistakes.

cursor += batch_size
if query_out_path is None:
query_out_path = (
query_file.parent
/ f"{query_file.stem}-{num_query_vectors}_{seed}"
f"{query_file.suffix}"
)
svs.write_vecs(queries_all, str(query_out_path))
queries_path = query_out_path
else:
queries_path = query_file
ground_truth_all = idxs
if out_file is None:
out_file = utils.ground_truth_path(
vecs_path,
queries_path,
distance,
num_vectors,
seed if shuffle else None,
)
svs.write_vecs(ground_truth_all.astype(np.uint32), str(out_file))
logger.info({"ground_truth_saved": out_file})


Expand Down
21 changes: 14 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@
8: svs.LeanVecKind.lvq8,
}

RANDOM_VECTORS_SHAPE: Final = (1000, 100)
NUM_RANDOM_QUERY_VECTORS: Final = 100
GROUND_TRUTH_K: Final = 100


def random_array(dtype: np.dtype) -> np.ndarray:
rng = np.random.default_rng(42)
if np.dtype(dtype).kind == "i":
iinfo = np.iinfo(dtype)
return rng.integers(iinfo.min, iinfo.max, (1000, 100), dtype=dtype)
return rng.integers(
iinfo.min, iinfo.max, RANDOM_VECTORS_SHAPE, dtype=dtype
)
else:
return rng.random((1000, 100)).astype(dtype)
return rng.random(RANDOM_VECTORS_SHAPE).astype(dtype)


@pytest.fixture(
scope="session", params=consts.SUFFIX_TO_SVS_TYPE.keys()
)
@pytest.fixture(scope="session", params=consts.SUFFIX_TO_SVS_TYPE.keys())
def tmp_vecs(request, tmp_path_factory):
suffix = request.param
vecs_path = tmp_path_factory.mktemp("vecs") / ("random" + suffix)
Expand Down Expand Up @@ -134,7 +138,10 @@ def index_dir_with_svs_type_and_dynamic(request, tmp_path_factory):
def query_path(tmp_path_factory) -> Path:
path = tmp_path_factory.mktemp("query") / "query.fvecs"
svs.write_vecs(
np.random.default_rng(42).random((100, 100)).astype(np.float32), path
np.random.default_rng(42)
.random((NUM_RANDOM_QUERY_VECTORS, RANDOM_VECTORS_SHAPE[1]))
.astype(np.float32),
path,
)
return path

Expand Down Expand Up @@ -162,7 +169,7 @@ def ground_truth_path(
)
vectors = np.load(index_dir / "data.npy")
index = svs.Flat(vectors, distance=distance, num_threads=num_threads)
idxs, _ = index.search(svs.read_vecs(str(query_path)), 100)
idxs, _ = index.search(svs.read_vecs(str(query_path)), GROUND_TRUTH_K)
ground_truth_path = (
tmp_path_factory.mktemp("ground_truth")
/ f"ground_truth_{index_svs_type}.ivecs"
Expand Down
134 changes: 134 additions & 0 deletions tests/test_generate_ground_truth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import functools

import conftest
import pytest
import svs

from svsbench.generate_ground_truth import generate_ground_truth, main


def test_generate_ground_truth_no_shuffle(
    tmp_vecs, query_path, distance, num_threads, tmp_path_factory
):
    """Check the ground truth written when vectors are not shuffled.

    Runs ``generate_ground_truth`` with ``shuffle=False`` and no vector
    limit, then verifies that the output file exists and has shape
    ``(NUM_RANDOM_QUERY_VECTORS, k)``.
    """
    # The .hvecs parametrization is reported as an expected failure.
    if tmp_vecs.suffix == ".hvecs":
        pytest.xfail("Not implemented")
    output_dir = tmp_path_factory.mktemp("output")
    ground_truth_file = output_dir / "ground_truth.ivecs"
    num_neighbors = 10
    generate_ground_truth(
        vecs_path=tmp_vecs,
        query_file=query_path,
        distance=distance,
        num_vectors=None,
        k=num_neighbors,
        num_threads=num_threads,
        out_file=ground_truth_file,
        query_out_path=None,
        shuffle=False,
        seed=42,
    )
    assert ground_truth_file.is_file()
    expected_shape = (conftest.NUM_RANDOM_QUERY_VECTORS, num_neighbors)
    ground_truth = svs.read_vecs(str(ground_truth_file))
    assert ground_truth.shape == expected_shape, (
        "Expected (num_queries, k) shape"
    )


def test_generate_ground_truth_shuffle(
    tmp_vecs, query_path, distance, num_threads, tmp_path_factory
):
    """Check the ground truth written when the vectors are shuffled.

    Runs ``generate_ground_truth`` with ``shuffle=True`` on the first 500
    vectors and verifies that the output file exists and has shape
    ``(NUM_RANDOM_QUERY_VECTORS, k)``.
    """
    # The .hvecs parametrization is reported as an expected failure.
    if tmp_vecs.suffix == ".hvecs":
        pytest.xfail("Not implemented")
    output_dir = tmp_path_factory.mktemp("output")
    ground_truth_file = output_dir / "ground_truth_shuffled.ivecs"
    num_neighbors = 5
    generate_ground_truth(
        vecs_path=tmp_vecs,
        query_file=query_path,
        distance=distance,
        num_vectors=500,
        k=num_neighbors,
        num_threads=num_threads,
        out_file=ground_truth_file,
        shuffle=True,
        seed=2,
    )
    assert ground_truth_file.is_file()
    ground_truth = svs.read_vecs(str(ground_truth_file))
    expected_shape = (conftest.NUM_RANDOM_QUERY_VECTORS, num_neighbors)
    assert ground_truth.shape == expected_shape


def test_generate_ground_truth_num_query_vectors(
    tmp_vecs, query_path, distance, num_threads, tmp_path_factory
):
    """Check output shapes when a number of query vectors is requested.

    For each requested count, both the ground truth file and the generated
    query file must contain exactly that many rows. The counts used are 20
    and 200; the same output paths are reused for both runs.
    """
    # The .hvecs parametrization is reported as an expected failure.
    if tmp_vecs.suffix == ".hvecs":
        pytest.xfail("Not implemented")
    num_neighbors = 7
    output_dir = tmp_path_factory.mktemp("output")
    ground_truth_file = output_dir / "ground_truth_subqueries.ivecs"
    queries_file = output_dir / "queries_out.fvecs"
    # Everything except the number of query vectors is shared by both runs.
    run_with_shared_args = functools.partial(
        generate_ground_truth,
        vecs_path=tmp_vecs,
        query_file=query_path,
        distance=distance,
        k=num_neighbors,
        num_threads=num_threads,
        out_file=ground_truth_file,
        query_out_path=queries_file,
        shuffle=True,
    )
    query_dim = conftest.RANDOM_VECTORS_SHAPE[1]
    for requested in (20, 200):
        run_with_shared_args(num_query_vectors=requested)
        ground_truth = svs.read_vecs(str(ground_truth_file))
        generated_queries = svs.read_vecs(str(queries_file))
        assert ground_truth.shape == (requested, num_neighbors)
        assert generated_queries.shape == (requested, query_dim)


def test_generate_ground_truth_main(
    tmp_vecs, query_path, num_threads, tmp_path_factory
):
    """Run the command-line entry point and check the output shapes.

    Invokes ``main`` with an explicit argument list requesting 150 query
    vectors and verifies that the ground truth and generated query files
    both contain that many rows.
    """
    # The .hvecs parametrization is reported as an expected failure.
    if tmp_vecs.suffix == ".hvecs":
        pytest.xfail("Not implemented")
    cli_dir = tmp_path_factory.mktemp("cli")
    ground_truth_file = cli_dir / "gt.ivecs"
    queries_file = cli_dir / "queries_out.fvecs"
    num_neighbors = 8
    requested_queries = 150

    # Insertion order of the mapping fixes the order of the arguments.
    options = {
        "--vecs_file": tmp_vecs,
        "--query_file": query_path,
        "--out_file": ground_truth_file,
        "--distance": "mip",
        "-k": num_neighbors,
        "--max_threads": num_threads,
        "--seed": "42",
        "--num_query_vectors": requested_queries,
        "--query_out_file": queries_file,
    }
    argv = []
    for flag, value in options.items():
        argv.extend((flag, str(value)))

    main(argv)

    ground_truth = svs.read_vecs(str(ground_truth_file))
    generated_queries = svs.read_vecs(str(queries_file))
    assert ground_truth.shape == (requested_queries, num_neighbors)
    assert generated_queries.shape == (
        requested_queries,
        conftest.RANDOM_VECTORS_SHAPE[1],
    )
Loading