
Commit 60d42bd (1 parent: 218c775)

refactor: Update all engines to use Query and Record dataclasses

13 files changed: +67 -67 lines
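The hunks in this commit only consume the new dataclasses; their definitions live in dataset_reader/base_reader.py and are not shown here. As a rough orientation, here is a minimal sketch of what those classes are assumed to look like, with fields inferred purely from how the engines use them below. The real definitions may differ, and the hunks themselves are inconsistent about the id field, spelling it record.idx in the Elasticsearch and Milvus clients but record.id in the OpenSearch and pgvector clients.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Record:
    # Integer point id; the hunks reference both `record.idx` and `record.id`.
    idx: int
    vector: List[float]
    metadata: Optional[dict] = None


@dataclass
class Query:
    vector: Optional[List[float]] = None
    sparse_vector: Optional[object] = None  # only the Qdrant hunk below references this
    meta_conditions: Optional[dict] = None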

engine/clients/elasticsearch/search.py (3 additions & 2 deletions)

@@ -4,6 +4,7 @@
 
 from elasticsearch import Elasticsearch
 
+from dataset_reader.base_reader import Query
 from engine.base_client.search import BaseSearcher
 from engine.clients.elasticsearch.config import (
     ELASTIC_INDEX,
@@ -46,10 +47,10 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
         cls.search_params = search_params
 
     @classmethod
-    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
+    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
         knn = {
             "field": "vector",
-            "query_vector": vector,
+            "query_vector": query.vector,
             "k": top,
             **{"num_candidates": 100, **cls.search_params},
         }

engine/clients/elasticsearch/upload.py (6 additions & 10 deletions)

@@ -4,6 +4,7 @@
 
 from elasticsearch import Elasticsearch
 
+from dataset_reader.base_reader import Record
 from engine.base_client.upload import BaseUploader
 from engine.clients.elasticsearch.config import (
     ELASTIC_INDEX,
@@ -44,19 +45,14 @@ def init_client(cls, host, distance, connection_params, upload_params):
         cls.upload_params = upload_params
 
     @classmethod
-    def upload_batch(
-        cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
-    ):
+    def upload_batch(cls, batch: List[Record]):
         if metadata is None:
-            metadata = [{}] * len(vectors)
+            metadata = [{}] * len(batch)
         operations = []
-        for idx, vector, payload in zip(ids, vectors, metadata):
-            vector_id = uuid.UUID(int=idx).hex
+        for record in batch:
+            vector_id = uuid.UUID(int=record.idx).hex
             operations.append({"index": {"_id": vector_id}})
-            if payload:
-                operations.append({"vector": vector, **payload})
-            else:
-                operations.append({"vector": vector})
+            operations.append({"vector": record.vector, **(record.metadata or {})})
 
         cls.client.bulk(
             index=ELASTIC_INDEX,
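The upload path now takes a single List[Record] batch instead of the parallel ids, vectors, and metadata lists. Note that the retained `if metadata is None:` context line above still refers to a name the new signature no longer binds, so that guard looks vestigial in this hunk. As a hedged illustration of the new contract, here is a hypothetical adapter that converts the legacy argument triple into a batch, using the Record stand-in sketched at the top of this page rather than the real dataclass:

from typing import List, Optional


def to_record_batch(
    ids: List[int],
    vectors: List[list],
    metadata: Optional[List[dict]] = None,
) -> List["Record"]:
    # Hypothetical adapter from the old upload_batch arguments to the new batch.
    metadata = metadata or [{}] * len(vectors)
    return [
        Record(idx=i, vector=v, metadata=m or None)
        for i, v, m in zip(ids, vectors, metadata)
    ]


# Hypothetical usage:
# ElasticUploader.upload_batch(to_record_batch(ids, vectors, payloads))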

engine/clients/milvus/search.py (4 additions & 3 deletions)

@@ -3,6 +3,7 @@
 
 from pymilvus import Collection, connections
 
+from dataset_reader.base_reader import Query
 from engine.base_client.search import BaseSearcher
 from engine.clients.milvus.config import (
     DISTANCE_MAPPING,
@@ -37,15 +38,15 @@ def get_mp_start_method(cls):
         return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
 
     @classmethod
-    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
+    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
         param = {"metric_type": cls.distance, "params": cls.search_params["params"]}
         try:
             res = cls.collection.search(
-                data=[vector],
+                data=[query.vector],
                 anns_field="vector",
                 param=param,
                 limit=top,
-                expr=cls.parser.parse(meta_conditions),
+                expr=cls.parser.parse(query.meta_conditions),
             )
         except Exception as e:
             import ipdb

engine/clients/milvus/upload.py (9 additions & 6 deletions)

@@ -8,6 +8,7 @@
     wait_for_index_building_complete,
 )
 
+from dataset_reader.base_reader import Record
 from engine.base_client.upload import BaseUploader
 from engine.clients.milvus.config import (
     DISTANCE_MAPPING,
@@ -41,20 +42,22 @@ def init_client(cls, host, distance, connection_params, upload_params):
         cls.distance = DISTANCE_MAPPING[distance]
 
     @classmethod
-    def upload_batch(
-        cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
-    ):
-        if metadata is not None:
+    def upload_batch(cls, batch: List[Record]):
+        has_metadata = any(record.metadata for record in batch)
+        if has_metadata:
             field_values = [
                 [
-                    payload.get(field_schema.name) or DTYPE_DEFAULT[field_schema.dtype]
-                    for payload in metadata
+                    record.metadata.get(field_schema.name)
+                    or DTYPE_DEFAULT[field_schema.dtype]
+                    for record in batch
                 ]
                 for field_schema in cls.collection.schema.fields
                 if field_schema.name not in ["id", "vector"]
             ]
         else:
             field_values = []
+        ids = [record.idx for record in batch]
+        vectors = [record.vector for record in batch]
         cls.collection.insert([ids, vectors] + field_values)
 
     @classmethod
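Milvus ingests data column-major, so the row-like Record batch has to be transposed into one value list per payload field before collection.insert. Below is a self-contained sketch of that transformation with hypothetical field names, dtype defaults, and stand-in classes (the real schema and DTYPE_DEFAULT mapping live in engine/clients/milvus/config.py), plus an extra guard for records that carry no metadata:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FieldSchemaStub:
    # Stand-in for pymilvus FieldSchema; only .name and .dtype are used here.
    name: str
    dtype: str


@dataclass
class RecordStub:
    idx: int
    vector: List[float]
    metadata: Optional[dict] = None


DTYPE_DEFAULT = {"int64": 0, "varchar": ""}  # hypothetical dtype defaults
schema_fields = [
    FieldSchemaStub("id", "int64"),
    FieldSchemaStub("vector", "float_vector"),
    FieldSchemaStub("color", "varchar"),
    FieldSchemaStub("count", "int64"),
]
batch = [
    RecordStub(1, [0.1, 0.2], {"color": "red", "count": 3}),
    RecordStub(2, [0.3, 0.4], None),
]

# One column per payload field, mirroring the shape of the refactored upload_batch
# (with an `or {}` guard added for records without metadata).
field_values = [
    [
        (record.metadata or {}).get(field.name) or DTYPE_DEFAULT[field.dtype]
        for record in batch
    ]
    for field in schema_fields
    if field.name not in ["id", "vector"]
]
ids = [record.idx for record in batch]
vectors = [record.vector for record in batch]
print(field_values)  # [['red', ''], [3, 0]]
# cls.collection.insert([ids, vectors] + field_values) is the column-major call.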

engine/clients/opensearch/search.py (3 additions & 2 deletions)

@@ -4,6 +4,7 @@
 
 from opensearchpy import OpenSearch
 
+from dataset_reader.base_reader import Query
 from engine.base_client.search import BaseSearcher
 from engine.clients.opensearch.config import (
     OPENSEARCH_INDEX,
@@ -46,11 +47,11 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
         cls.search_params = search_params
 
     @classmethod
-    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
+    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
         query = {
             "knn": {
                 "vector": {
-                    "vector": vector,
+                    "vector": query.vector,
                     "k": top,
                 }
             }
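A small subtlety in this hunk: the request body keeps the local name query, which shadows the new query: Query parameter. It still works because the dict literal on the right-hand side is evaluated while query is still bound to the parameter, and only then is the name rebound; the pgvector hunk further down avoids the issue by renaming its local to sql_query. A minimal stand-alone illustration of that evaluation order, with hypothetical names:

class ExampleQuery:
    # Stand-in for the dataset Query; only the attribute access matters here.
    vector = [0.1, 0.2, 0.3]


query = ExampleQuery()
# The right-hand side is evaluated first, while `query` is still the
# ExampleQuery instance; only afterwards is the name rebound to the dict.
query = {"knn": {"vector": {"vector": query.vector, "k": 10}}}
assert query["knn"]["vector"]["vector"] == [0.1, 0.2, 0.3]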

engine/clients/opensearch/upload.py (7 additions & 11 deletions)

@@ -1,9 +1,10 @@
 import multiprocessing as mp
 import uuid
-from typing import List, Optional
+from typing import List
 
 from opensearchpy import OpenSearch
 
+from dataset_reader.base_reader import Record
 from engine.base_client.upload import BaseUploader
 from engine.clients.opensearch.config import (
     OPENSEARCH_INDEX,
@@ -44,19 +45,14 @@ def init_client(cls, host, distance, connection_params, upload_params):
         cls.upload_params = upload_params
 
     @classmethod
-    def upload_batch(
-        cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
-    ):
+    def upload_batch(cls, batch: List[Record]):
         if metadata is None:
-            metadata = [{}] * len(vectors)
+            metadata = [{}] * len(batch)
         operations = []
-        for idx, vector, payload in zip(ids, vectors, metadata):
-            vector_id = uuid.UUID(int=idx).hex
+        for record in batch:
+            vector_id = uuid.UUID(int=record.id).hex
             operations.append({"index": {"_id": vector_id}})
-            if payload:
-                operations.append({"vector": vector, **payload})
-            else:
-                operations.append({"vector": vector})
+            operations.append({"vector": record.vector, **(record.metadata or {})})
 
         cls.client.bulk(
             index=OPENSEARCH_INDEX,

engine/clients/pgvector/search.py (6 additions & 5 deletions)

@@ -5,6 +5,7 @@
 import psycopg
 from pgvector.psycopg import register_vector
 
+from dataset_reader.base_reader import Query
 from engine.base_client.distances import Distance
 from engine.base_client.search import BaseSearcher
 from engine.clients.pgvector.config import get_db_config
@@ -27,19 +28,19 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
         cls.search_params = search_params["search_params"]
 
     @classmethod
-    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
+    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
         cls.cur.execute(f"SET hnsw.ef_search = {cls.search_params['hnsw_ef']}")
 
         if cls.distance == Distance.COSINE:
-            query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};"
+            sql_query = f"SELECT id, embedding <=> %s AS _score FROM items ORDER BY _score LIMIT {top};"
         elif cls.distance == Distance.L2:
-            query = f"SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT {top};"
+            sql_query = f"SELECT id, embedding <-> %s AS _score FROM items ORDER BY _score LIMIT {top};"
         else:
             raise NotImplementedError(f"Unsupported distance metric {cls.distance}")
 
         cls.cur.execute(
-            query,
-            (np.array(vector),),
+            sql_query,
+            (np.array(query.vector),),
         )
         return cls.cur.fetchall()

engine/clients/pgvector/upload.py (4 additions & 5 deletions)

@@ -4,6 +4,7 @@
 import psycopg
 from pgvector.psycopg import register_vector
 
+from dataset_reader.base_reader import Record
 from engine.base_client.upload import BaseUploader
 from engine.clients.pgvector.config import get_db_config
 
@@ -21,15 +22,13 @@ def init_client(cls, host, distance, connection_params, upload_params):
         cls.upload_params = upload_params
 
     @classmethod
-    def upload_batch(
-        cls, ids: List[int], vectors: List[list], metadata: Optional[List[dict]]
-    ):
+    def upload_batch(cls, batch: List[Record]):
         vectors = np.array(vectors)
 
         # Copy is faster than insert
         with cls.cur.copy("COPY items (id, embedding) FROM STDIN") as copy:
-            for i, embedding in zip(ids, vectors):
-                copy.write_row((i, embedding))
+            for record in batch:
+                copy.write_row((record.id, record.vector))
 
     @classmethod
     def delete_client(cls):

engine/clients/qdrant/search.py (1 addition & 1 deletion)

@@ -35,7 +35,7 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
         # return "forkserver" if "forkserver" in mp.get_all_start_methods() else "spawn"
 
     @classmethod
-    def search_one(cls, query: Query, top) -> List[Tuple[int, float]]:
+    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
         # Can query only one till we introduce re-ranking in the benchmarks
         if query.sparse_vector is not None:
             query_vector = rest.NamedSparseVector(

engine/clients/redis/search.py (6 additions & 5 deletions)

@@ -3,8 +3,9 @@
 
 import numpy as np
 from redis import Redis, RedisCluster
-from redis.commands.search.query import Query
+from redis.commands.search.query import Query as RedisQuery
 
+from dataset_reader.base_reader import Query as DatasetQuery
 from engine.base_client.search import BaseSearcher
 from engine.clients.redis.config import (
     REDIS_AUTH,
@@ -41,16 +42,16 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
         cls._ft = cls.conns[random.randint(0, len(cls.conns)) - 1].ft()
 
     @classmethod
-    def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
-        conditions = cls.parser.parse(meta_conditions)
+    def search_one(cls, query: DatasetQuery, top: int) -> List[Tuple[int, float]]:
+        conditions = cls.parser.parse(query.meta_conditions)
         if conditions is None:
             prefilter_condition = "*"
             params = {}
         else:
             prefilter_condition, params = conditions
 
         q = (
-            Query(
+            RedisQuery(
                 f"{prefilter_condition}=>[KNN $K @vector $vec_param {cls.knn_conditions} AS vector_score]"
             )
             .sort_by("vector_score", asc=True)
@@ -62,7 +63,7 @@ def search_one(cls, vector, meta_conditions, top) -> List[Tuple[int, float]]:
             .timeout(REDIS_QUERY_TIMEOUT)
         )
         params_dict = {
-            "vec_param": np.array(vector).astype(np.float32).tobytes(),
+            "vec_param": np.array(query.vector).astype(np.float32).tobytes(),
             "K": top,
             "EF": cls.search_params["search_params"]["ef"],
             **params,
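The Redis client is the one engine that needs both redis-py's Query builder and the dataset-level Query, hence the RedisQuery and DatasetQuery import aliases. With every engine now exposing the same search_one(query, top) signature, a benchmark driver can stay engine-agnostic; below is a hedged sketch of such a loop, using the Query stand-in from the top of this page and a hypothetical searcher class:

from typing import List, Tuple, Type


def run_queries(
    searcher_cls: Type,        # hypothetical: any engine's *Searcher class
    queries: List["Query"],    # the Query stand-in sketched at the top
    top: int = 10,
) -> List[List[Tuple[int, float]]]:
    # Engine-agnostic driver: after this commit every searcher accepts a Query.
    results = []
    for query in queries:
        results.append(searcher_cls.search_one(query, top))
    return results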
