1
1
import multiprocessing as mp
2
2
from typing import List , Tuple
3
3
4
- from cassandra .cluster import Cluster , ExecutionProfile , EXEC_PROFILE_DEFAULT
5
- from cassandra .policies import DCAwareRoundRobinPolicy , TokenAwarePolicy , ExponentialReconnectionPolicy
6
4
from cassandra import ConsistencyLevel , ProtocolVersion
5
+ from cassandra .cluster import EXEC_PROFILE_DEFAULT , Cluster , ExecutionProfile
6
+ from cassandra .policies import (
7
+ DCAwareRoundRobinPolicy ,
8
+ ExponentialReconnectionPolicy ,
9
+ TokenAwarePolicy ,
10
+ )
7
11
8
12
from dataset_reader .base_reader import Query
9
13
from engine .base_client .distances import Distance
@@ -24,20 +28,22 @@ def init_client(cls, host, distance, connection_params: dict, search_params: dic
24
28
profile = ExecutionProfile (
25
29
load_balancing_policy = TokenAwarePolicy (DCAwareRoundRobinPolicy ()),
26
30
consistency_level = ConsistencyLevel .LOCAL_ONE , # Use LOCAL_ONE for faster reads
27
- request_timeout = 60
31
+ request_timeout = 60 ,
28
32
)
29
-
33
+
30
34
# Initialize Cassandra cluster connection
31
35
cls .cluster = Cluster (
32
- contact_points = [host ],
36
+ contact_points = [host ],
33
37
execution_profiles = {EXEC_PROFILE_DEFAULT : profile },
34
- reconnection_policy = ExponentialReconnectionPolicy (base_delay = 1 , max_delay = 60 ),
38
+ reconnection_policy = ExponentialReconnectionPolicy (
39
+ base_delay = 1 , max_delay = 60
40
+ ),
35
41
protocol_version = ProtocolVersion .V4 ,
36
- ** connection_params
42
+ ** connection_params ,
37
43
)
38
44
cls .session = cls .cluster .connect (CASSANDRA_KEYSPACE )
39
45
cls .search_params = search_params
40
-
46
+
41
47
# Update prepared statements with current search parameters
42
48
cls .update_prepared_statements (distance )
43
49
@@ -50,7 +56,7 @@ def update_prepared_statements(cls, distance):
50
56
"""Create prepared statements for vector searches"""
51
57
# Prepare a vector similarity search query
52
58
limit = cls .search_params .get ("top" , 10 )
53
-
59
+
54
60
if distance == Distance .COSINE :
55
61
SIMILARITY_FUNC = "similarity_cosine"
56
62
elif distance == Distance .L2 :
@@ -61,48 +67,49 @@ def update_prepared_statements(cls, distance):
61
67
raise ValueError (f"Unsupported distance metric: { distance } " )
62
68
63
69
cls .ann_search_stmt = cls .session .prepare (
64
- f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
65
- FROM { CASSANDRA_TABLE }
70
+ f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
71
+ FROM { CASSANDRA_TABLE }
66
72
ORDER BY embedding ANN OF ?
67
73
LIMIT { limit } """
68
74
)
69
-
75
+
70
76
# Prepare a statement for filtered vector search
71
- cls .filtered_search_query_template = (
72
- f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
73
- FROM { CASSANDRA_TABLE }
77
+ cls .filtered_search_query_template = f"""SELECT id, { SIMILARITY_FUNC } (embedding, ?) as distance
78
+ FROM { CASSANDRA_TABLE }
74
79
WHERE {{conditions}}
75
80
ORDER BY embedding ANN OF ?
76
81
LIMIT { limit } """
77
- )
78
82
79
83
@classmethod
80
84
def search_one (cls , query : Query , top : int ) -> List [Tuple [int , float ]]:
81
85
"""Execute a vector similarity search with optional filters"""
82
86
# Convert query vector to a format Cassandra can use
83
- query_vector = query .vector .tolist () if hasattr (query .vector , 'tolist' ) else query .vector
84
-
87
+ query_vector = (
88
+ query .vector .tolist () if hasattr (query .vector , "tolist" ) else query .vector
89
+ )
90
+
85
91
# Generate filter conditions if metadata conditions exist
86
92
filter_conditions = cls .parser .parse (query .meta_conditions )
87
-
93
+
88
94
try :
89
95
if filter_conditions :
90
96
# Use the filtered search query
91
- query_with_conditions = cls .filtered_search_query_template .format (conditions = filter_conditions )
97
+ query_with_conditions = cls .filtered_search_query_template .format (
98
+ conditions = filter_conditions
99
+ )
92
100
results = cls .session .execute (
93
101
cls .session .prepare (query_with_conditions ),
94
- (query_vector , query_vector )
102
+ (query_vector , query_vector ),
95
103
)
96
104
else :
97
105
# Use the basic ANN search query
98
106
results = cls .session .execute (
99
- cls .ann_search_stmt ,
100
- (query_vector , query_vector )
107
+ cls .ann_search_stmt , (query_vector , query_vector )
101
108
)
102
-
109
+
103
110
# Extract and return results
104
111
return [(row .id , row .distance ) for row in results ]
105
-
112
+
106
113
except Exception as ex :
107
114
print (f"Error during Cassandra vector search: { ex } " )
108
115
raise ex
@@ -113,4 +120,4 @@ def delete_client(cls):
113
120
if cls .session :
114
121
cls .session .shutdown ()
115
122
if cls .cluster :
116
- cls .cluster .shutdown ()
123
+ cls .cluster .shutdown ()
0 commit comments