Skip to content

Commit a6e25a3

Browse files
committed
fix: Refactor existing retry logic for aborted transactions and clean up redundant code
1 parent c634bdb commit a6e25a3

File tree

5 files changed

+127
-62
lines changed

5 files changed

+127
-62
lines changed

google/cloud/spanner_v1/_helpers.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -464,20 +464,43 @@ def _metadata_with_prefix(prefix, **kw):
464464
return [("google-cloud-resource-prefix", prefix)]
465465

466466

467+
def _retry_on_aborted_exception(
    func,
    deadline,
    allowed_exceptions=None,
):
    """
    Call ``func`` until it succeeds, retrying ``Aborted`` errors until ``deadline``.

    Any other exception raised by ``func`` is handed to ``_retry`` together with
    ``allowed_exceptions``; whatever ``_retry`` returns is returned to the caller
    unchanged, including ``None``.

    Args:
        func: Zero-argument callable to invoke.
        deadline: Absolute timestamp (compared against ``time.time()`` values by
            ``_delay_until_retry``) after which ``Aborted`` is no longer retried.
        allowed_exceptions: Optional mapping of exception class to handler,
            forwarded to ``_retry`` for non-``Aborted`` failures.

    Returns:
        The value returned by the first successful call of ``func``.

    Raises:
        Aborted: Re-raised by ``_delay_until_retry`` once the deadline cannot be met.
        Exception: Whatever ``_retry`` raises for non-retryable or exhausted errors.
    """
    attempts = 0
    while True:
        attempts += 1
        try:
            try:
                return func()
            except Aborted:
                # Let the outer handler deal with it so that every Aborted,
                # direct or nested, goes through the same deadline check.
                raise
            except Exception:
                # _retry either returns func's result or raises, so its return
                # value is passed through as-is; a legitimate None result must
                # not be mistaken for a failure.
                return _retry(func=func, allowed_exceptions=allowed_exceptions)
        except Aborted as exc:
            # Must run inside this handler: _delay_until_retry uses a bare
            # ``raise`` to propagate the Aborted when the deadline is exceeded.
            _delay_until_retry(exc, deadline=deadline, attempts=attempts)
493+
494+
467495
def _retry(
468496
func,
469497
retry_count=5,
470498
delay=2,
471499
allowed_exceptions=None,
472500
beforeNextRetry=None,
473-
deadline=None,
474501
):
475502
"""
476-
Retry a specified function with different logic based on the type of exception raised.
477-
478-
If the exception is of type google.api_core.exceptions.Aborted,
479-
apply an alternate retry strategy that relies on the provided deadline value instead of a fixed number of retries.
480-
For all other exceptions, retry the function up to a specified number of times.
503+
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
481504
482505
Args:
483506
func: The function to be retried.
@@ -491,21 +514,13 @@ def _retry(
491514
The result of the function if it is successful, or raises the last exception if all retries fail.
492515
"""
493516
retries = 0
494-
while True:
517+
while retries <= retry_count:
495518
if retries > 0 and beforeNextRetry:
496519
beforeNextRetry(retries, delay)
497520

498521
try:
499522
return func()
500523
except Exception as exc:
501-
if isinstance(exc, Aborted) and deadline is not None:
502-
if (
503-
allowed_exceptions is not None
504-
and allowed_exceptions.get(exc.__class__) is not None
505-
):
506-
retries += 1
507-
_delay_until_retry(exc, deadline=deadline, attempts=retries)
508-
continue
509524
if (
510525
allowed_exceptions is None or exc.__class__ in allowed_exceptions
511526
) and retries < retry_count:
@@ -568,7 +583,6 @@ def _delay_until_retry(exc, deadline, attempts):
568583
raise
569584

570585
delay = _get_retry_delay(cause, attempts)
571-
print(now, delay, deadline)
572586
if delay is not None:
573587
if now + delay > deadline:
574588
raise

google/cloud/spanner_v1/batch.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
3030
from google.cloud.spanner_v1 import RequestOptions
3131
from google.cloud.spanner_v1._helpers import _retry
32+
from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception
3233
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
3334
from google.api_core.exceptions import InternalServerError
34-
from google.api_core.exceptions import Aborted
3535
import time
3636

3737
DEFAULT_RETRY_TIMEOUT_SECS = 30
@@ -235,11 +235,10 @@ def commit(
235235
deadline = time.time() + kwargs.get(
236236
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
237237
)
238-
response = _retry(
238+
response = _retry_on_aborted_exception(
239239
method,
240240
allowed_exceptions={
241241
InternalServerError: _check_rst_stream_error,
242-
Aborted: no_op_handler,
243242
},
244243
deadline=deadline,
245244
)
@@ -360,16 +359,11 @@ def batch_write(
360359
request=request,
361360
metadata=metadata,
362361
)
363-
deadline = time.time() + kwargs.get(
364-
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
365-
)
366362
response = _retry(
367363
method,
368364
allowed_exceptions={
369365
InternalServerError: _check_rst_stream_error,
370-
Aborted: no_op_handler,
371366
},
372-
deadline=deadline,
373367
)
374368
self.committed = True
375369
return response
@@ -393,8 +387,3 @@ def _make_write_pb(table, columns, values):
393387
return Mutation.Write(
394388
table=table, columns=columns, values=_make_list_value_pbs(values)
395389
)
396-
397-
398-
def no_op_handler(exc):
    """Accept an exception and deliberately ignore it."""
    return None

tests/unit/test__helpers.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,98 @@ def test_check_rst_stream_error(self):
882882

883883
self.assertEqual(test_api.test_fxn.call_count, 3)
884884

885+
def test_retry_on_aborted_exception_with_success_after_first_aborted_retry(self):
    """
    One Aborted failure followed by a success: expects two calls and the
    second call's value as the result.
    """
    import functools
    import time

    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    test_api = mock.create_autospec(self.test_class)
    # First call aborts, second call returns the sentinel value.
    test_api.test_fxn.side_effect = [
        Aborted("aborted exception", errors=("Aborted error")),
        "true",
    ]
    deadline = time.time() + 30

    result_after_retry = _retry_on_aborted_exception(
        functools.partial(test_api.test_fxn), deadline
    )

    self.assertEqual(test_api.test_fxn.call_count, 2)
    # Pin the exact value; assertTrue would accept any truthy result.
    self.assertEqual(result_after_retry, "true")
903+
904+
def test_retry_on_aborted_exception_with_success_after_three_retries(self):
    """
    Two InternalServerErrors, then an Aborted, then a success: expects four
    calls and the final call's value as the result.
    """
    import functools
    import time

    from google.api_core.exceptions import Aborted
    from google.api_core.exceptions import InternalServerError
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    test_api = mock.create_autospec(self.test_class)
    # Case where aborted exception is thrown after other generic exceptions
    test_api.test_fxn.side_effect = [
        InternalServerError("testing"),
        InternalServerError("testing"),
        Aborted("aborted exception", errors=("Aborted error")),
        "true",
    ]
    allowed_exceptions = {
        InternalServerError: lambda exc: None,
    }
    deadline = time.time() + 30

    result_after_retries = _retry_on_aborted_exception(
        functools.partial(test_api.test_fxn),
        deadline=deadline,
        allowed_exceptions=allowed_exceptions,
    )

    self.assertEqual(test_api.test_fxn.call_count, 4)
    # The value of the one successful call must reach the caller.
    self.assertEqual(result_after_retries, "true")
930+
931+
def test_retry_on_aborted_exception_raises_aborted_if_deadline_expires(self):
    """
    With a deadline only 0.1 seconds away, expects the Aborted error to
    propagate after a single call.
    """
    import functools
    import time

    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    fake_api = mock.create_autospec(self.test_class)
    outcomes = [
        Aborted("aborted exception", errors=("Aborted error")),
        "true",
    ]
    fake_api.test_fxn.side_effect = outcomes
    expires_at = time.time() + 0.1
    call = functools.partial(fake_api.test_fxn)

    with self.assertRaises(Aborted):
        _retry_on_aborted_exception(call, deadline=expires_at)

    # The queued success value must never be reached.
    self.assertEqual(fake_api.test_fxn.call_count, 1)
949+
950+
def test_retry_on_aborted_exception_returns_response_after_internal_server_errors(
    self,
):
    """
    Two allowed InternalServerErrors followed by a success: expects three
    calls and the final call's value as the result.
    """
    import functools
    import time

    from google.api_core.exceptions import InternalServerError
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    test_api = mock.create_autospec(self.test_class)
    test_api.test_fxn.side_effect = [
        InternalServerError("testing"),
        InternalServerError("testing"),
        "true",
    ]
    allowed_exceptions = {
        InternalServerError: lambda exc: None,
    }
    deadline = time.time() + 30

    result_after_retries = _retry_on_aborted_exception(
        functools.partial(test_api.test_fxn),
        deadline=deadline,
        allowed_exceptions=allowed_exceptions,
    )

    self.assertEqual(test_api.test_fxn.call_count, 3)
    # Pin the exact value; assertTrue would accept any truthy result.
    self.assertEqual(result_after_retries, "true")
976+
885977

886978
class Test_metadata_with_leader_aware_routing(unittest.TestCase):
887979
def _call_fut(self, *args, **kw):

tests/unit/test_batch.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515

16-
import time
1716
import unittest
1817
from unittest.mock import MagicMock
1918
from tests._helpers import (
@@ -651,36 +650,6 @@ def __init__(self, database=None, name=TestBatch.SESSION_NAME):
651650
def session_id(self):
652651
return self.name
653652

654-
def run_in_transaction(self, fnc):
    """
    Call ``fnc`` and return its result, retrying when it raises Aborted.

    Up to three attempts are made, with a one-second pause between them.
    The Aborted error from the third attempt is raised to the caller; any
    other exception is reported and re-raised immediately.

    :param fnc: Zero-argument callable to run.
    :return: The value returned by ``fnc``.
    """
    from google.api_core.exceptions import Aborted

    total_attempts = 3
    pause_secs = 1
    for attempt_no in range(1, total_attempts + 1):
        try:
            return fnc()
        except Aborted as exc:
            if attempt_no >= total_attempts:
                raise exc  # After max retries, raise the exception
            print(
                f"Attempt {attempt_no} failed with Aborted. Retrying in {pause_secs} seconds..."
            )
            time.sleep(pause_secs)  # Wait before retrying
        except Exception as exc:
            print(f"Unexpected exception occurred: {exc}")
            raise  # Raise any other unexpected exception immediately
683-
684653

685654
class _Database(object):
686655
name = "testing"

tests/unit/test_database.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,12 +1931,13 @@ def test_context_mgr_w_commit_stats_error(self):
19311931
return_commit_stats=True,
19321932
request_options=RequestOptions(),
19331933
)
1934-
api.commit.assert_called_once_with(
1934+
self.assertEqual(api.commit.call_count, 2)
1935+
api.commit.assert_any_call(
19351936
request=request,
19361937
metadata=[
19371938
("google-cloud-resource-prefix", database.name),
19381939
("x-goog-spanner-route-to-leader", "true"),
1939-
],
1940+
]
19401941
)
19411942

19421943
database.logger.info.assert_not_called()

0 commit comments

Comments
 (0)