Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions google/cloud/spanner_dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ def connect(
pool=None,
user_agent=None,
client=None,
route_to_leader_enabled=False,
):
"""Creates a connection to a Google Cloud Spanner database.

Expand Down Expand Up @@ -544,6 +545,14 @@ def connect(
:class:`~google.cloud.spanner_v1.Client`.
:param client: (Optional) Custom user provided Client Object

:type route_to_leader_enabled: boolean
:param route_to_leader_enabled:
(Optional) Default False. Set route_to_leader_enabled as True to
Enable leader aware routing. Enabling leader aware routing
would route all requests in RW/PDML transactions to the
leader region.


:rtype: :class:`google.cloud.spanner_dbapi.connection.Connection`
:returns: Connection object associated with the given Google Cloud Spanner
resource.
Expand All @@ -556,11 +565,17 @@ def connect(
)
if isinstance(credentials, str):
client = spanner.Client.from_service_account_json(
credentials, project=project, client_info=client_info
credentials,
project=project,
client_info=client_info,
route_to_leader_enabled=False,
)
else:
client = spanner.Client(
project=project, credentials=credentials, client_info=client_info
project=project,
credentials=credentials,
client_info=client_info,
route_to_leader_enabled=False,
)
else:
if project is not None and client.project != project:
Expand Down
12 changes: 12 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,15 @@ def _metadata_with_prefix(prefix, **kw):
List[Tuple[str, str]]: RPC metadata with supplied prefix
"""
return [("google-cloud-resource-prefix", prefix)]


def _metadata_with_leader_aware_routing(value, **kw):
"""Create RPC metadata containing a leader aware routing header

Args:
value (bool): header value

Returns:
List[Tuple[str, str]]: RPC metadata with leader aware routing header
"""
return ("x-goog-spanner-route-to-leader", str(value).lower())
9 changes: 8 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._helpers import _make_list_value_pbs
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions

Expand Down Expand Up @@ -159,6 +162,10 @@ def commit(self, return_commit_stats=False, request_options=None):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
trace_attributes = {"num_mutations": len(self._mutations)}

Expand Down
19 changes: 19 additions & 0 deletions google/cloud/spanner_v1/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ class Client(ClientWithProject):
If a dict is provided, it must be of the same form as the protobuf
message :class:`~google.cloud.spanner_v1.types.QueryOptions`

:type route_to_leader_enabled: boolean
:param route_to_leader_enabled:
(Optional) Default False. Set route_to_leader_enabled as True to
Enable leader aware routing. Enabling leader aware routing
would route all requests in RW/PDML transactions to the
leader region.

:raises: :class:`ValueError <exceptions.ValueError>` if both ``read_only``
and ``admin`` are :data:`True`
"""
Expand All @@ -132,6 +139,7 @@ def __init__(
client_info=_CLIENT_INFO,
client_options=None,
query_options=None,
route_to_leader_enabled=False,
):
self._emulator_host = _get_spanner_emulator_host()

Expand Down Expand Up @@ -171,6 +179,8 @@ def __init__(
):
warnings.warn(_EMULATOR_HOST_HTTP_SCHEME)

self._route_to_leader_enabled = route_to_leader_enabled

@property
def credentials(self):
"""Getter for client's credentials.
Expand Down Expand Up @@ -242,6 +252,15 @@ def database_admin_api(self):
)
return self._database_admin_api

@property
def route_to_leader_enabled(self):
"""Getter for if read-write or pdml requests will be routed to leader.

:rtype: boolean
:returns: If read-write requests will be routed to leader.
"""
return self._route_to_leader_enabled

def copy(self):
"""Make a copy of this client.

Expand Down
10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1 import SpannerClient
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.keyset import KeySet
from google.cloud.spanner_v1.pool import BurstyPool
Expand Down Expand Up @@ -155,6 +158,7 @@ def __init__(
self._encryption_config = encryption_config
self._database_dialect = database_dialect
self._database_role = database_role
self._route_to_leader_enabled = self._instance._client.route_to_leader_enabled

if pool is None:
pool = BurstyPool(database_role=database_role)
Expand Down Expand Up @@ -565,6 +569,10 @@ def execute_partitioned_dml(
)

metadata = _metadata_with_prefix(self.name)
if self._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(self._route_to_leader_enabled)
)

def execute_pdml():
with SessionCheckout(self._pool) as session:
Expand Down
13 changes: 12 additions & 1 deletion google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
from google.cloud.exceptions import NotFound
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
from google.cloud.spanner_v1 import Session
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from warnings import warn

_NOW = datetime.datetime.utcnow # unit tests may replace
Expand Down Expand Up @@ -191,6 +194,10 @@ def bind(self, database):
self._database = database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
self._database_role = self._database_role or self._database.database_role
request = BatchCreateSessionsRequest(
database=database.name,
Expand Down Expand Up @@ -402,6 +409,10 @@ def bind(self, database):
self._database = database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
created_session_count = 0
self._database_role = self._database_role or self._database.database_role

Expand Down
17 changes: 16 additions & 1 deletion google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@

from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import CreateSessionRequest
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.snapshot import Snapshot
Expand Down Expand Up @@ -125,6 +128,12 @@ def create(self):
raise ValueError("Session ID already set by back-end")
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
if self._database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
self._database._route_to_leader_enabled
)
)

request = CreateSessionRequest(database=self._database.name)
if self._database.database_role is not None:
Expand Down Expand Up @@ -153,6 +162,12 @@ def exists(self):
return False
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
if self._database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
self._database._route_to_leader_enabled
)
)

with trace_call("CloudSpanner.GetSession", self) as span:
try:
Expand Down
29 changes: 26 additions & 3 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import _metadata_with_prefix
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.streamed import StreamedResultSet
Expand Down Expand Up @@ -235,6 +238,10 @@ def read(
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if not self._read_only and database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

if request_options is None:
request_options = RequestOptions()
Expand All @@ -244,7 +251,7 @@ def read(
if self._read_only:
# Transaction tags are not supported for read only transactions.
request_options.transaction_tag = None
else:
elif self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

request = ReadRequest(
Expand Down Expand Up @@ -391,6 +398,10 @@ def execute_sql(

database = self._session._database
metadata = _metadata_with_prefix(database.name)
if not self._read_only and database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)

api = database.spanner_api

Expand All @@ -406,7 +417,7 @@ def execute_sql(
if self._read_only:
# Transaction tags are not supported for read only transactions.
request_options.transaction_tag = None
else:
elif self.transaction_tag is not None:
request_options.transaction_tag = self.transaction_tag

request = ExecuteSqlRequest(
Expand Down Expand Up @@ -527,6 +538,10 @@ def partition_read(
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
transaction = self._make_txn_selector()
partition_options = PartitionOptions(
partition_size_bytes=partition_size_bytes, max_partitions=max_partitions
Expand Down Expand Up @@ -621,6 +636,10 @@ def partition_query(
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
transaction = self._make_txn_selector()
partition_options = PartitionOptions(
partition_size_bytes=partition_size_bytes, max_partitions=max_partitions
Expand Down Expand Up @@ -766,6 +785,10 @@ def begin(self):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if not self._read_only and database._route_to_leader_enabled:
metadata.append(
(_metadata_with_leader_aware_routing(database._route_to_leader_enabled))
)
txn_selector = self._make_txn_selector()
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
Expand Down
24 changes: 24 additions & 0 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
_make_value_pb,
_merge_query_options,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
)
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
Expand Down Expand Up @@ -50,6 +51,7 @@ class Transaction(_SnapshotBase, _BatchBase):
_multi_use = True
_execute_sql_count = 0
_lock = threading.Lock()
_read_only = False

def __init__(self, session):
if session._transaction is not None:
Expand Down Expand Up @@ -124,6 +126,10 @@ def begin(self):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
Expand All @@ -140,6 +146,12 @@ def rollback(self):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(
database._route_to_leader_enabled
)
)
with trace_call("CloudSpanner.Rollback", self._session):
api.rollback(
session=self._session.name,
Expand Down Expand Up @@ -176,6 +188,10 @@ def commit(self, return_commit_stats=False, request_options=None):
database = self._session._database
api = database.spanner_api
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
trace_attributes = {"num_mutations": len(self._mutations)}

if request_options is None:
Expand Down Expand Up @@ -294,6 +310,10 @@ def execute_update(
params_pb = self._make_params_pb(params, param_types)
database = self._session._database
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
api = database.spanner_api

seqno, self._execute_sql_count = (
Expand Down Expand Up @@ -406,6 +426,10 @@ def batch_update(self, statements, request_options=None):

database = self._session._database
metadata = _metadata_with_prefix(database.name)
if database._route_to_leader_enabled:
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
api = database.spanner_api

seqno, self._execute_sql_count = (
Expand Down
Loading