Skip to content

Commit ab768e4

Browse files
authored
feat: support requiest options in !autocommit mode (#838)
1 parent 06725fc commit ab768e4

File tree

5 files changed

+68
-16
lines changed

5 files changed

+68
-16
lines changed

google/cloud/spanner_dbapi/_helpers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,21 @@
4747
}
4848

4949

50-
def _execute_insert_heterogenous(transaction, sql_params_list):
50+
def _execute_insert_heterogenous(
51+
transaction,
52+
sql_params_list,
53+
request_options=None,
54+
):
5155
for sql, params in sql_params_list:
5256
sql, params = sql_pyformat_args_to_spanner(sql, params)
53-
transaction.execute_update(sql, params, get_param_types(params))
57+
transaction.execute_update(
58+
sql, params, get_param_types(params), request_options=request_options
59+
)
5460

5561

5662
def handle_insert(connection, sql, params):
5763
return connection.database.run_in_transaction(
58-
_execute_insert_heterogenous, ((sql, params),)
64+
_execute_insert_heterogenous, ((sql, params),), connection.request_options
5965
)
6066

6167

google/cloud/spanner_dbapi/connection.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,21 @@ def read_only(self, value):
183183
)
184184
self._read_only = value
185185

186+
@property
187+
def request_options(self):
188+
"""Options for the next SQL operations.
189+
190+
Returns:
191+
google.cloud.spanner_v1.RequestOptions:
192+
Request options.
193+
"""
194+
if self.request_priority is None:
195+
return
196+
197+
req_opts = RequestOptions(priority=self.request_priority)
198+
self.request_priority = None
199+
return req_opts
200+
186201
@property
187202
def staleness(self):
188203
"""Current read staleness option value of this `Connection`.
@@ -437,25 +452,19 @@ def run_statement(self, statement, retried=False):
437452

438453
if statement.is_insert:
439454
_execute_insert_heterogenous(
440-
transaction, ((statement.sql, statement.params),)
455+
transaction, ((statement.sql, statement.params),), self.request_options
441456
)
442457
return (
443458
iter(()),
444459
ResultsChecksum() if retried else statement.checksum,
445460
)
446461

447-
if self.request_priority is not None:
448-
req_opts = RequestOptions(priority=self.request_priority)
449-
self.request_priority = None
450-
else:
451-
req_opts = None
452-
453462
return (
454463
transaction.execute_sql(
455464
statement.sql,
456465
statement.params,
457466
param_types=statement.param_types,
458-
request_options=req_opts,
467+
request_options=self.request_options,
459468
),
460469
ResultsChecksum() if retried else statement.checksum,
461470
)

google/cloud/spanner_dbapi/cursor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,10 @@ def close(self):
172172

173173
def _do_execute_update(self, transaction, sql, params):
174174
result = transaction.execute_update(
175-
sql, params=params, param_types=get_param_types(params)
175+
sql,
176+
params=params,
177+
param_types=get_param_types(params),
178+
request_options=self.connection.request_options,
176179
)
177180
self._itr = None
178181
if type(result) == int:
@@ -278,7 +281,9 @@ def execute(self, sql, args=None):
278281
_helpers.handle_insert(self.connection, sql, args or None)
279282
else:
280283
self.connection.database.run_in_transaction(
281-
self._do_execute_update, sql, args or None
284+
self._do_execute_update,
285+
sql,
286+
args or None,
282287
)
283288
except (AlreadyExists, FailedPrecondition, OutOfRange) as e:
284289
raise IntegrityError(getattr(e, "details", e)) from e
@@ -421,7 +426,12 @@ def fetchmany(self, size=None):
421426
return items
422427

423428
def _handle_DQL_with_snapshot(self, snapshot, sql, params):
424-
self._result_set = snapshot.execute_sql(sql, params, get_param_types(params))
429+
self._result_set = snapshot.execute_sql(
430+
sql,
431+
params,
432+
get_param_types(params),
433+
request_options=self.connection.request_options,
434+
)
425435
# Read the first element so that the StreamedResultSet can
426436
# return the metadata after a DQL statement.
427437
self._itr = PeekIterator(self._result_set)

tests/unit/spanner_dbapi/test__helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def test__execute_insert_heterogenous(self):
3737

3838
mock_pyformat.assert_called_once_with(params[0], params[1])
3939
mock_param_types.assert_called_once_with(None)
40-
mock_update.assert_called_once_with(sql, None, None)
40+
mock_update.assert_called_once_with(
41+
sql, None, None, request_options=None
42+
)
4143

4244
def test__execute_insert_heterogenous_error(self):
4345
from google.cloud.spanner_dbapi import _helpers
@@ -62,7 +64,9 @@ def test__execute_insert_heterogenous_error(self):
6264

6365
mock_pyformat.assert_called_once_with(params[0], params[1])
6466
mock_param_types.assert_called_once_with(None)
65-
mock_update.assert_called_once_with(sql, None, None)
67+
mock_update.assert_called_once_with(
68+
sql, None, None, request_options=None
69+
)
6670

6771
def test_handle_insert(self):
6872
from google.cloud.spanner_dbapi import _helpers

tests/unit/spanner_dbapi/test_cursor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,29 @@ def test_handle_dql(self):
748748
self.assertIsInstance(cursor._itr, utils.PeekIterator)
749749
self.assertEqual(cursor._row_count, _UNSET_COUNT)
750750

751+
def test_handle_dql_priority(self):
752+
from google.cloud.spanner_dbapi import utils
753+
from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
754+
from google.cloud.spanner_v1 import RequestOptions
755+
756+
connection = self._make_connection(self.INSTANCE, mock.MagicMock())
757+
connection.database.snapshot.return_value.__enter__.return_value = (
758+
mock_snapshot
759+
) = mock.MagicMock()
760+
connection.request_priority = 1
761+
762+
cursor = self._make_one(connection)
763+
764+
sql = "sql"
765+
mock_snapshot.execute_sql.return_value = ["0"]
766+
cursor._handle_DQL(sql, params=None)
767+
self.assertEqual(cursor._result_set, ["0"])
768+
self.assertIsInstance(cursor._itr, utils.PeekIterator)
769+
self.assertEqual(cursor._row_count, _UNSET_COUNT)
770+
mock_snapshot.execute_sql.assert_called_with(
771+
sql, None, None, request_options=RequestOptions(priority=1)
772+
)
773+
751774
def test_context(self):
752775
connection = self._make_connection(self.INSTANCE, self.DATABASE)
753776
cursor = self._make_one(connection)

0 commit comments

Comments
 (0)