Skip to content

Commit 0b205bd

Browse files
Parse LWT flags when creating prepared statement
1 parent aaa7808 commit 0b205bd

File tree

5 files changed

+24
-5
lines changed

5 files changed

+24
-5
lines changed

cassandra/cluster.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3109,7 +3109,9 @@ def prepare(self, query, custom_payload=None, keyspace=None):
31093109
prepared_keyspace = keyspace if keyspace else None
31103110
prepared_statement = PreparedStatement.from_message(
31113111
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
3112-
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
3112+
self._protocol_version, response.column_metadata, response.result_metadata_id,
3113+
future._current_host.lwt_info.is_lwt(response.flags) if future._current_host.lwt_info is not None else False,
3114+
self.cluster.column_encryption_policy)
31133115
prepared_statement.custom_payload = future.custom_payload
31143116

31153117
self.cluster.add_prepared(response.query_id, prepared_statement)

cassandra/lwt_info.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,6 @@ class _LwtInfo:
1919

2020
def __init__(self, lwt_meta_bit_mask):
2121
self.lwt_meta_bit_mask = lwt_meta_bit_mask
22+
23+
def is_lwt(self, flags):
24+
return (flags & self.lwt_meta_bit_mask) == self.lwt_meta_bit_mask

cassandra/pool.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ class Host(object):
167167

168168
sharding_info = None
169169

170+
lwt_info = None
171+
170172
def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None):
171173
if endpoint is None:
172174
raise ValueError("endpoint may not be None")
@@ -438,6 +440,8 @@ def __init__(self, host, host_distance, session):
438440
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
439441
self.host.sharding_info = first_connection.features.sharding_info
440442
self._open_connections_for_all_shards(first_connection.features.shard_id)
443+
if first_connection.features.lwt_info is not None:
444+
self.host.lwt_info = first_connection.features.lwt_info
441445
self.tablets_routing_v1 = first_connection.features.tablets_routing_v1
442446

443447
log.debug("Finished initializing connection for host %s", self.host)

cassandra/protocol.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ class ResultMessage(_MessageType):
686686
bind_metadata = None
687687
pk_indexes = None
688688
schema_change_event = None
689+
flags = None
689690

690691
def __init__(self, kind):
691692
self.kind = kind
@@ -787,6 +788,7 @@ def recv_results_metadata(self, f, user_type_map):
787788

788789
def recv_prepared_metadata(self, f, protocol_version, user_type_map):
789790
flags = read_int(f)
791+
self.flags = flags
790792
colcount = read_int(f)
791793
pk_indexes = None
792794
if protocol_version >= 4:

cassandra/query.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,10 +454,11 @@ class PreparedStatement(object):
454454
routing_key_indexes = None
455455
_routing_key_index_set = None
456456
serial_consistency_level = None # TODO never used?
457+
_is_lwt = False
457458

458459
def __init__(self, column_metadata, query_id, routing_key_indexes, query,
459460
keyspace, protocol_version, result_metadata, result_metadata_id,
460-
column_encryption_policy=None):
461+
is_lwt, column_encryption_policy=None):
461462
self.column_metadata = column_metadata
462463
self.query_id = query_id
463464
self.routing_key_indexes = routing_key_indexes
@@ -468,15 +469,16 @@ def __init__(self, column_metadata, query_id, routing_key_indexes, query,
468469
self.result_metadata_id = result_metadata_id
469470
self.column_encryption_policy = column_encryption_policy
470471
self.is_idempotent = False
472+
self._is_lwt = is_lwt
471473

472474
@classmethod
473475
def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
474476
query, prepared_keyspace, protocol_version, result_metadata,
475-
result_metadata_id, column_encryption_policy=None):
477+
result_metadata_id, is_lwt, column_encryption_policy=None):
476478
if not column_metadata:
477479
return PreparedStatement(column_metadata, query_id, None,
478480
query, prepared_keyspace, protocol_version, result_metadata,
479-
result_metadata_id, column_encryption_policy)
481+
result_metadata_id, is_lwt, column_encryption_policy)
480482

481483
if pk_indexes:
482484
routing_key_indexes = pk_indexes
@@ -502,7 +504,7 @@ def from_message(cls, query_id, column_metadata, pk_indexes, cluster_metadata,
502504

503505
return PreparedStatement(column_metadata, query_id, routing_key_indexes,
504506
query, prepared_keyspace, protocol_version, result_metadata,
505-
result_metadata_id, column_encryption_policy)
507+
result_metadata_id, is_lwt, column_encryption_policy)
506508

507509
def bind(self, values):
508510
"""
@@ -517,6 +519,9 @@ def is_routing_key_index(self, i):
517519
self._routing_key_index_set = set(self.routing_key_indexes) if self.routing_key_indexes else set()
518520
return i in self._routing_key_index_set
519521

522+
def is_lwt(self):
523+
return self._is_lwt
524+
520525
def __str__(self):
521526
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
522527
return (u'<PreparedStatement query="%s", consistency=%s>' %
@@ -682,6 +687,9 @@ def routing_key(self):
682687

683688
return self._routing_key
684689

690+
def is_lwt(self):
691+
return self.prepared_statement.is_lwt
692+
685693
def __str__(self):
686694
consistency = ConsistencyLevel.value_to_name.get(self.consistency_level, 'Not Set')
687695
return (u'<BoundStatement query="%s", values=%s, consistency=%s>' %

0 commit comments

Comments
 (0)