2020import logging
2121import socket
2222import time
23+ import random
2324from threading import Lock , RLock , Condition
2425import weakref
2526try :
@@ -123,6 +124,8 @@ class Host(object):
123124
124125 _currently_handling_node_up = False
125126
127+ sharding_info = None
128+
126129 def __init__ (self , endpoint , conviction_policy_factory , datacenter = None , rack = None , host_id = None ):
127130 if endpoint is None :
128131 raise ValueError ("endpoint may not be None" )
@@ -339,7 +342,6 @@ class HostConnection(object):
339342 shutdown_on_error = False
340343
341344 _session = None
342- _connection = None
343345 _lock = None
344346 _keyspace = None
345347
@@ -351,6 +353,7 @@ def __init__(self, host, host_distance, session):
351353 # this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
352354 self ._stream_available_condition = Condition (self ._lock )
353355 self ._is_replacing = False
356+ self ._connections = dict ()
354357
355358 if host_distance == HostDistance .IGNORED :
356359 log .debug ("Not opening connection to ignored host %s" , self .host )
@@ -360,18 +363,45 @@ def __init__(self, host, host_distance, session):
360363 return
361364
362365 log .debug ("Initializing connection for host %s" , self .host )
363- self ._connection = session .cluster .connection_factory (host .endpoint )
366+ first_connection = session .cluster .connection_factory (host .endpoint )
367+ log .debug ("first connection created for shard_id=%i" , first_connection .shard_id )
368+ self ._connections [first_connection .shard_id ] = first_connection
364369 self ._keyspace = session .keyspace
370+
365371 if self ._keyspace :
366- self ._connection .set_keyspace_blocking (self ._keyspace )
372+ first_connection .set_keyspace_blocking (self ._keyspace )
373+
374+ if first_connection .sharding_info :
375+ self .host .sharding_info = weakref .proxy (first_connection .sharding_info )
376+ for _ in range (first_connection .sharding_info .shards_count * 2 ):
377+ conn = self ._session .cluster .connection_factory (self .host .endpoint )
378+ if conn .shard_id not in self ._connections .keys ():
379+ log .debug ("new connection created for shard_id=%i" , conn .shard_id )
380+ self ._connections [conn .shard_id ] = conn
381+ if self ._keyspace :
382+ self ._connections [conn .shard_id ].set_keyspace_blocking (self ._keyspace )
383+
384+ if len (self ._connections .keys ()) == first_connection .sharding_info .shards_count :
385+ break
386+ if not len (self ._connections .keys ()) == first_connection .sharding_info .shards_count :
387+ raise NoConnectionsAvailable ("not enough shard connection opened" )
388+
367389 log .debug ("Finished initializing connection for host %s" , self .host )
368390
369- def borrow_connection (self , timeout ):
391+ def borrow_connection (self , timeout , routing_key = None ):
370392 if self .is_shutdown :
371393 raise ConnectionException (
372394 "Pool for %s is shutdown" % (self .host ,), self .host )
373395
374- conn = self ._connection
396+ shard_id = 0
397+ if self .host .sharding_info :
398+ if routing_key :
399+ t = self ._session .cluster .metadata .token_map .token_class .from_key (routing_key )
400+ shard_id = self .host .sharding_info .shard_id (t )
401+ else :
402+ shard_id = random .randint (0 , self .host .sharding_info .shards_count - 1 )
403+
404+ conn = self ._connections .get (shard_id )
375405 if not conn :
376406 raise NoConnectionsAvailable ()
377407
@@ -416,7 +446,7 @@ def return_connection(self, connection):
416446 if is_down :
417447 self .shutdown ()
418448 else :
419- self ._connection = None
449+ del self ._connections [ connection . shard_id ]
420450 with self ._lock :
421451 if self ._is_replacing :
422452 return
@@ -433,7 +463,7 @@ def _replace(self, connection):
433463 conn = self ._session .cluster .connection_factory (self .host .endpoint )
434464 if self ._keyspace :
435465 conn .set_keyspace_blocking (self ._keyspace )
436- self ._connection = conn
466+ self ._connections [ connection . shard_id ] = conn
437467 except Exception :
438468 log .warning ("Failed reconnecting %s. Retrying." % (self .host .endpoint ,))
439469 self ._session .submit (self ._replace , connection )
@@ -450,36 +480,48 @@ def shutdown(self):
450480 self .is_shutdown = True
451481 self ._stream_available_condition .notify_all ()
452482
453- if self ._connection :
454- self ._connection .close ()
455- self ._connection = None
483+ if self ._connections :
484+ for c in self ._connections .values ():
485+ c .close ()
486+ self ._connections = dict ()
456487
457488 def _set_keyspace_for_all_conns (self , keyspace , callback ):
458- if self .is_shutdown or not self ._connection :
489+ """
490+ Asynchronously sets the keyspace for all connections. When all
491+ connections have been set, `callback` will be called with two
492+ arguments: this pool, and a list of any errors that occurred.
493+ """
494+ remaining_callbacks = set (self ._connections .values ())
495+ errors = []
496+
497+ if not remaining_callbacks :
498+ callback (self , errors )
459499 return
460500
461501 def connection_finished_setting_keyspace (conn , error ):
462502 self .return_connection (conn )
463- errors = [] if not error else [error ]
464- callback (self , errors )
503+ remaining_callbacks .remove (conn )
504+ if error :
505+ errors .append (error )
506+
507+ if not remaining_callbacks :
508+ callback (self , errors )
465509
466510 self ._keyspace = keyspace
467- self ._connection .set_keyspace_async (keyspace , connection_finished_setting_keyspace )
511+ for conn in self ._connections .values ():
512+ conn .set_keyspace_async (keyspace , connection_finished_setting_keyspace )
468513
469514 def get_connections (self ):
470- c = self ._connection
471- return [ c ] if c else []
515+ c = self ._connections
516+ return list ( self . _connections . values ()) if c else []
472517
473518 def get_state (self ):
474- connection = self ._connection
475- open_count = 1 if connection and not (connection .is_closed or connection .is_defunct ) else 0
476- in_flights = [connection .in_flight ] if connection else []
477- return {'shutdown' : self .is_shutdown , 'open_count' : open_count , 'in_flights' : in_flights }
519+ in_flights = [c .in_flight for c in self ._connections .values ()]
520+ return {'shutdown' : self .is_shutdown , 'open_count' : self .open_count , 'in_flights' : in_flights }
478521
479522 @property
480523 def open_count (self ):
481- connection = self ._connection
482- return 1 if connection and not (connection .is_closed or connection .is_defunct ) else 0
524+ return sum ([1 if c and not (c .is_closed or c .is_defunct ) else 0 for c in self ._connections .values ()])
483525
484526_MAX_SIMULTANEOUS_CREATION = 1
485527_MIN_TRASH_INTERVAL = 10
@@ -522,7 +564,7 @@ def __init__(self, host, host_distance, session):
522564 self .open_count = core_conns
523565 log .debug ("Finished initializing new connection pool for host %s" , self .host )
524566
525- def borrow_connection (self , timeout ):
567+ def borrow_connection (self , timeout , routing_key = None ):
526568 if self .is_shutdown :
527569 raise ConnectionException (
528570 "Pool for %s is shutdown" % (self .host ,), self .host )
0 commit comments