9 changes: 7 additions & 2 deletions cassandra/cluster.py
@@ -492,7 +492,8 @@ def _profiles_without_explicit_lbps(self):

def distance(self, host):
distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values())
return HostDistance.LOCAL if HostDistance.LOCAL in distances else \
return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \
HostDistance.LOCAL if HostDistance.LOCAL in distances else \
HostDistance.REMOTE if HostDistance.REMOTE in distances else \
HostDistance.IGNORED

@@ -609,7 +610,7 @@ class Cluster(object):

Defaults to loopback interface.

Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit
Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit
local_dc set (as is the default), the DC is chosen from an arbitrary
host in contact_points. In this case, contact_points should contain
only nodes from a single, local DC.
@@ -1369,21 +1370,25 @@ def __init__(self,
self._user_types = defaultdict(dict)

self._min_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS,
HostDistance.LOCAL: DEFAULT_MIN_REQUESTS,
HostDistance.REMOTE: DEFAULT_MIN_REQUESTS
}

self._max_requests_per_connection = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS,
HostDistance.LOCAL: DEFAULT_MAX_REQUESTS,
HostDistance.REMOTE: DEFAULT_MAX_REQUESTS
}

self._core_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST
}

self._max_connections_per_host = {
HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST,
HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST
}
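The distance() change above makes the profile manager report the closest distance that any profile's policy assigns to a host, with the new ordering LOCAL_RACK before LOCAL before REMOTE before IGNORED. A minimal sketch of that resolution rule, assuming the HostDistance.LOCAL_RACK constant introduced in policies.py later in this PR:

# Illustrative sketch, not part of the diff: pick the closest distance any
# profile reports for a host, mirroring _ProfileManager.distance() above.
from cassandra.policies import HostDistance

def closest_distance(distances):
    for d in (HostDistance.LOCAL_RACK, HostDistance.LOCAL, HostDistance.REMOTE):
        if d in distances:
            return d
    return HostDistance.IGNORED

assert closest_distance({HostDistance.REMOTE, HostDistance.LOCAL_RACK}) == HostDistance.LOCAL_RACK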
2 changes: 1 addition & 1 deletion cassandra/metadata.py
@@ -3436,7 +3436,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
all_replicas = cluster.metadata.get_replicas(keyspace, routing_key)
# First check if there are local replicas
valid_replicas = [host for host in all_replicas if
host.is_up and distance(host) == HostDistance.LOCAL]
host.is_up and distance(host) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]]
if not valid_replicas:
valid_replicas = [host for host in all_replicas if host.is_up]

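With this change, group_keys_by_replica also groups keys under replicas that are merely rack-local instead of dropping to the any-up-replica fallback. A short usage sketch, assuming an already connected session and the test1.table1 schema used by the integration test at the end of this PR:

# Illustrative sketch, not part of the diff: group routing keys by the up
# replica (LOCAL or LOCAL_RACK distance) that owns them; keys with no such
# replica are grouped under cassandra.metadata.NO_VALID_REPLICA.
from cassandra.metadata import group_keys_by_replica

keys = [(pk,) for pk in range(10)]  # routing keys as partition-key tuples
keys_per_host = group_keys_by_replica(session, "test1", "table1", keys)
for host, host_keys in keys_per_host.items():
    print(host, host_keys)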
152 changes: 146 additions & 6 deletions cassandra/policies.py
@@ -46,7 +46,18 @@ class HostDistance(object):
connections opened to it.
"""

LOCAL = 0
LOCAL_RACK = 0
"""
Nodes with ``LOCAL_RACK`` distance will be preferred for operations
under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`)
and will have a greater number of connections opened against
them by default.

This distance is typically used for nodes within the same
datacenter and the same rack as the client.
"""

LOCAL = 1
"""
Nodes with ``LOCAL`` distance will be preferred for operations
under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
@@ -57,12 +68,12 @@ class HostDistance(object):
datacenter as the client.
"""

REMOTE = 1
REMOTE = 2
"""
Nodes with ``REMOTE`` distance will be treated as a last resort
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`)
and will have a smaller number of connections opened against
them by default.
by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`
and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of
connections opened against them by default.

This distance is typically used for nodes outside of the
datacenter that the client is running in.
@@ -102,6 +113,11 @@ class LoadBalancingPolicy(HostStateListener):

You may also use subclasses of :class:`.LoadBalancingPolicy` for
custom behavior.

You should always use immutable collections (e.g., tuples or
frozensets) to store information about hosts to prevent accidental
modification. When there are changes to the hosts (e.g., a host is
down or up), the old collection should be replaced with a new one.
"""

_hosts_lock = None
@@ -316,6 +332,130 @@ def on_add(self, host):
def on_remove(self, host):
self.on_down(host)

class RackAwareRoundRobinPolicy(LoadBalancingPolicy):
"""
Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts
in the local rack, then hosts in the local datacenter but a
different rack, and finally hosts in all other datacenters.
"""

local_dc = None
local_rack = None
used_hosts_per_remote_dc = 0

def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0):
"""
The `local_dc` and `local_rack` parameters should be the name of the
datacenter and rack (such as is reported by ``nodetool ring``) that
should be considered local.

`used_hosts_per_remote_dc` controls how many nodes in
each remote datacenter will have connections opened
against them. In other words, `used_hosts_per_remote_dc` hosts
will be considered :attr:`~.HostDistance.REMOTE` and the
rest will be considered :attr:`~.HostDistance.IGNORED`.
By default, all remote hosts are ignored.
Review comment on lines +352 to +357: I started to think if there should be a similar mechanism for racks, but I don't immediately see how it would work.

"""
self.local_rack = local_rack
self.local_dc = local_dc
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
self._live_hosts = {}
self._dc_live_hosts = {}
self._endpoints = []
self._position = 0
LoadBalancingPolicy.__init__(self)

def _rack(self, host):
return host.rack or self.local_rack

def _dc(self, host):
return host.datacenter or self.local_dc

def populate(self, cluster, hosts):
for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))):
self._live_hosts[(dc, rack)] = tuple(set(rack_hosts))
for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)):
self._dc_live_hosts[dc] = tuple(set(dc_hosts))

self._position = randint(0, len(hosts) - 1) if hosts else 0

def distance(self, host):
rack = self._rack(host)
dc = self._dc(host)
if rack == self.local_rack and dc == self.local_dc:
return HostDistance.LOCAL_RACK

if dc == self.local_dc:
return HostDistance.LOCAL

if not self.used_hosts_per_remote_dc:
return HostDistance.IGNORED

dc_hosts = self._dc_live_hosts.get(dc, ())
if not dc_hosts:
return HostDistance.IGNORED
if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc:
return HostDistance.REMOTE
else:
return HostDistance.IGNORED

def make_query_plan(self, working_keyspace=None, query=None):
pos = self._position
self._position += 1

local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ())
pos = (pos % len(local_rack_live)) if local_rack_live else 0
# Slice the cyclic iterator to start from pos and include the next len(local_rack_live) elements
# This ensures we get exactly one full cycle of the rack-local hosts starting from pos
for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)):
yield host

local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack]
pos = (pos % len(local_live)) if local_live else 0
for host in islice(cycle(local_live), pos, pos + len(local_live)):
yield host

# the dict can change, so get candidate DCs iterating over keys of a copy
for dc, remote_live in self._dc_live_hosts.copy().items():
if dc != self.local_dc:
for host in remote_live[:self.used_hosts_per_remote_dc]:
yield host

def on_up(self, host):
dc = self._dc(host)
rack = self._rack(host)
with self._hosts_lock:
current_rack_hosts = self._live_hosts.get((dc, rack), ())
if host not in current_rack_hosts:
self._live_hosts[(dc, rack)] = current_rack_hosts + (host, )
current_dc_hosts = self._dc_live_hosts.get(dc, ())
if host not in current_dc_hosts:
self._dc_live_hosts[dc] = current_dc_hosts + (host, )
Review comment on lines +429 to +433: I see that DCAwarePolicy does the same thing. I wonder if it is possible for on_up to be called for an already present host. It probably is, otherwise those checks wouldn't exist; why the driver would do this, I have no idea.
Author reply: I am not sure, to be honest.


def on_down(self, host):
dc = self._dc(host)
rack = self._rack(host)
with self._hosts_lock:
current_rack_hosts = self._live_hosts.get((dc, rack), ())
if host in current_rack_hosts:
hosts = tuple(h for h in current_rack_hosts if h != host)
if hosts:
self._live_hosts[(dc, rack)] = hosts
else:
del self._live_hosts[(dc, rack)]
current_dc_hosts = self._dc_live_hosts.get(dc, ())
if host in current_dc_hosts:
hosts = tuple(h for h in current_dc_hosts if h != host)
if hosts:
self._dc_live_hosts[dc] = hosts
else:
del self._dc_live_hosts[dc]

def on_add(self, host):
self.on_up(host)

def on_remove(self, host):
self.on_down(host)

class TokenAwarePolicy(LoadBalancingPolicy):
"""
@@ -390,7 +530,7 @@ def make_query_plan(self, working_keyspace=None, query=None):
shuffle(replicas)

for replica in replicas:
if replica.is_up and child.distance(replica) == HostDistance.LOCAL:
if replica.is_up and child.distance(replica) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]:
yield replica

for host in child.make_query_plan(keyspace, query):
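Taken together, the pieces above can be wired up as follows; this is a minimal sketch, assuming the RackAwareRoundRobinPolicy added in this PR and example datacenter/rack names "DC1"/"RC1". Wrapping the policy in TokenAwarePolicy now yields rack-local replicas first, since the make_query_plan change above treats LOCAL_RACK like LOCAL:

# Illustrative sketch, not part of the diff: token-aware routing on top of the
# new rack-aware policy. Up replicas at LOCAL_RACK or LOCAL distance are tried
# first; remaining hosts come from the child policy's rack-then-DC ordering.
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.policies import RackAwareRoundRobinPolicy, TokenAwarePolicy

policy = TokenAwarePolicy(
    RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=1))
profile = ExecutionProfile(load_balancing_policy=policy)
cluster = Cluster(contact_points=["127.0.0.1"],
                  execution_profiles={EXEC_PROFILE_DEFAULT: profile})
session = cluster.connect()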
3 changes: 3 additions & 0 deletions docs/api/cassandra/policies.rst
@@ -18,6 +18,9 @@ Load Balancing
.. autoclass:: DCAwareRoundRobinPolicy
:members:

.. autoclass:: RackAwareRoundRobinPolicy
:members:

.. autoclass:: WhiteListRoundRobinPolicy
:members:

89 changes: 89 additions & 0 deletions tests/integration/standard/test_rack_aware_policy.py
@@ -0,0 +1,89 @@
import logging
import unittest

from cassandra.cluster import Cluster
from cassandra.policies import ConstantReconnectionPolicy, RackAwareRoundRobinPolicy

from tests.integration import PROTOCOL_VERSION, get_cluster, use_multidc

LOGGER = logging.getLogger(__name__)

def setup_module():
use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 3}})

class RackAwareRoundRobinPolicyTests(unittest.TestCase):
@classmethod
def setup_class(cls):
cls.cluster = Cluster(contact_points=[node.address() for node in get_cluster().nodelist()], protocol_version=PROTOCOL_VERSION,
load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=0),
reconnection_policy=ConstantReconnectionPolicy(1))
cls.session = cls.cluster.connect()
cls.create_ks_and_cf(cls)
cls.create_data(cls.session)
cls.node1, cls.node2, cls.node3, cls.node4, cls.node5, cls.node6, cls.node7 = get_cluster().nodes.values()

@classmethod
def teardown_class(cls):
cls.cluster.shutdown()

def create_ks_and_cf(self):
self.session.execute(
"""
DROP KEYSPACE IF EXISTS test1
"""
)
self.session.execute(
"""
CREATE KEYSPACE test1
WITH replication = {
'class': 'NetworkTopologyStrategy',
'replication_factor': 3
}
""")

self.session.execute(
"""
CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
""")

@staticmethod
def create_data(session):
prepared = session.prepare(
"""
INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?)
""")

for i in range(50):
bound = prepared.bind((i, i%5, i%2))
session.execute(bound)

def test_rack_aware(self):
prepared = self.session.prepare(
"""
SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
""")

for i in range (10):
bound = prepared.bind([i])
results = self.session.execute(bound)
self.assertEqual(results, [(i, i%5, i%2)])
coordinator = str(results.response_future.coordinator_host.endpoint)
self.assertTrue(coordinator in set(["127.0.0.1:9042", "127.0.0.2:9042"]))

self.node2.stop(wait_other_notice=True, gently=True)

for i in range (10):
bound = prepared.bind([i])
results = self.session.execute(bound)
self.assertEqual(results, [(i, i%5, i%2)])
coordinator = str(results.response_future.coordinator_host.endpoint)
self.assertEqual(coordinator, "127.0.0.1:9042")

self.node1.stop(wait_other_notice=True, gently=True)

for i in range (10):
bound = prepared.bind([i])
results = self.session.execute(bound)
self.assertEqual(results, [(i, i%5, i%2)])
coordinator = str(results.response_future.coordinator_host.endpoint)
self.assertTrue(coordinator in set(["127.0.0.3:9042", "127.0.0.4:9042"]))