Skip to content

Commit 687e0e3

Browse files
committed
fix: updating retry strategy to call _retry_on_aborted_exception for batch api calls
1 parent 0c88b16 commit 687e0e3

File tree

4 files changed

+122 additions, -31 deletions (lines changed)

google/cloud/spanner_v1/_helpers.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -464,20 +464,32 @@ def _metadata_with_prefix(prefix, **kw):
464464
return [("google-cloud-resource-prefix", prefix)]
465465

466466

467-
def _retry_on_aborted_exception(exc, deadline, attempts, allowed_exceptions):
467+
def _retry_on_aborted_exception(
468+
func,
469+
deadline,
470+
allowed_exceptions=None,
471+
):
468472
"""
469-
Handles the retry logic for Aborted exceptions, considering the deadline.
470-
Returns True if the exception is retried, False otherwise.
473+
Handles retry logic for Aborted exceptions, considering the deadline.
474+
Retries the function in case of Aborted exceptions and other allowed exceptions.
471475
"""
472-
if isinstance(exc, Aborted) and deadline is not None:
473-
# The logic for handling Aborted exceptions
474-
if (
475-
allowed_exceptions is not None
476-
and allowed_exceptions.get(exc.__class__) is not None
477-
):
476+
attempts = 0
477+
while True:
478+
try:
479+
attempts += 1
480+
return func()
481+
except Aborted as exc:
478482
_delay_until_retry(exc, deadline=deadline, attempts=attempts)
479-
return True
480-
return False
483+
continue
484+
except Exception as exc:
485+
try:
486+
retry_result = _retry(func=func, allowed_exceptions=allowed_exceptions)
487+
if retry_result is not None:
488+
return retry_result
489+
else:
490+
raise exc
491+
except Aborted:
492+
continue
481493

482494

483495
def _retry(
@@ -486,14 +498,9 @@ def _retry(
486498
delay=2,
487499
allowed_exceptions=None,
488500
beforeNextRetry=None,
489-
deadline=None,
490501
):
491502
"""
492-
Retry a specified function with different logic based on the type of exception raised.
493-
494-
If the exception is of type google.api_core.exceptions.Aborted,
495-
apply an alternate retry strategy that relies on the provided deadline value instead of a fixed number of retries.
496-
For all other exceptions, retry the function up to a specified number of times.
503+
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
497504
498505
Args:
499506
func: The function to be retried.
@@ -507,19 +514,13 @@ def _retry(
507514
The result of the function if it is successful, or raises the last exception if all retries fail.
508515
"""
509516
retries = 0
510-
attempts = 0
511-
while True:
512-
if retries > retry_count:
513-
raise Exception("Exceeded retry count.")
517+
while retries <= retry_count:
514518
if retries > 0 and beforeNextRetry:
515519
beforeNextRetry(retries, delay)
516520

517521
try:
518-
attempts += 1
519522
return func()
520523
except Exception as exc:
521-
if _retry_on_aborted_exception(exc, deadline, attempts, allowed_exceptions):
522-
continue
523524
if (
524525
allowed_exceptions is None or exc.__class__ in allowed_exceptions
525526
) and retries < retry_count:
@@ -582,7 +583,6 @@ def _delay_until_retry(exc, deadline, attempts):
582583
raise
583584

584585
delay = _get_retry_delay(cause, attempts)
585-
print(now, delay, deadline)
586586
if delay is not None:
587587
if now + delay > deadline:
588588
raise

google/cloud/spanner_v1/batch.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
3030
from google.cloud.spanner_v1 import RequestOptions
3131
from google.cloud.spanner_v1._helpers import _retry
32+
from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception
3233
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
3334
from google.api_core.exceptions import InternalServerError
34-
from google.api_core.exceptions import Aborted
3535
import time
3636

3737
DEFAULT_RETRY_TIMEOUT_SECS = 30
@@ -235,11 +235,10 @@ def commit(
235235
deadline = time.time() + kwargs.get(
236236
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
237237
)
238-
response = _retry(
238+
response = _retry_on_aborted_exception(
239239
method,
240240
allowed_exceptions={
241241
InternalServerError: _check_rst_stream_error,
242-
Aborted: lambda exc: None,
243242
},
244243
deadline=deadline,
245244
)
@@ -388,4 +387,3 @@ def _make_write_pb(table, columns, values):
388387
return Mutation.Write(
389388
table=table, columns=columns, values=_make_list_value_pbs(values)
390389
)
391-

tests/unit/test__helpers.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,98 @@ def test_check_rst_stream_error(self):
882882

883883
self.assertEqual(test_api.test_fxn.call_count, 3)
884884

885+
def test_retry_on_aborted_exception_with_success_after_first_aborted_retry(self):
    """A single Aborted error is retried and the second call's value is returned."""
    import functools
    import time

    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    test_api = mock.create_autospec(self.test_class)
    test_api.test_fxn.side_effect = [
        # Trailing comma matters: ("Aborted error") is a plain str, not a tuple.
        Aborted("aborted exception", errors=("Aborted error",)),
        "true",
    ]
    deadline = time.time() + 30
    result_after_retry = _retry_on_aborted_exception(
        functools.partial(test_api.test_fxn), deadline
    )

    self.assertEqual(test_api.test_fxn.call_count, 2)
    # Pin the exact value; assertTrue would accept any truthy result.
    self.assertEqual(result_after_retry, "true")
903+
904+
def test_retry_on_aborted_exception_with_success_after_three_retries(self):
    """Two InternalServerErrors and one Aborted are all retried before success."""
    import functools
    import time

    from google.api_core.exceptions import Aborted
    from google.api_core.exceptions import InternalServerError
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    test_api = mock.create_autospec(self.test_class)
    # Case where aborted exception is thrown after other generic exceptions
    test_api.test_fxn.side_effect = [
        InternalServerError("testing"),
        InternalServerError("testing"),
        # Trailing comma matters: ("Aborted error") is a plain str, not a tuple.
        Aborted("aborted exception", errors=("Aborted error",)),
        "true",
    ]
    allowed_exceptions = {
        InternalServerError: lambda exc: None,
    }
    deadline = time.time() + 30
    result_after_retries = _retry_on_aborted_exception(
        functools.partial(test_api.test_fxn),
        deadline=deadline,
        allowed_exceptions=allowed_exceptions,
    )

    self.assertEqual(test_api.test_fxn.call_count, 4)
    # The value produced by the final, successful call must be propagated.
    self.assertEqual(result_after_retries, "true")
930+
931+
def test_retry_on_aborted_exception_raises_aborted_if_deadline_expires(self):
    """Aborted propagates, with no second call, when the deadline is too close."""
    import functools
    import time

    from google.api_core.exceptions import Aborted
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    test_api = mock.create_autospec(self.test_class)
    test_api.test_fxn.side_effect = [
        # Trailing comma matters: ("Aborted error") is a plain str, not a tuple.
        Aborted("aborted exception", errors=("Aborted error",)),
        "true",
    ]
    # Only 0.1 s of budget; presumably shorter than any computed retry delay,
    # so the helper is expected to re-raise instead of sleeping.
    deadline = time.time() + 0.1
    with self.assertRaises(Aborted):
        _retry_on_aborted_exception(
            functools.partial(test_api.test_fxn), deadline=deadline
        )

    self.assertEqual(test_api.test_fxn.call_count, 1)
949+
950+
def test_retry_on_aborted_exception_returns_response_after_internal_server_errors(
    self,
):
    """Allowed InternalServerErrors are retried and the final value is returned."""
    import functools
    import time

    from google.api_core.exceptions import InternalServerError
    from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception

    test_api = mock.create_autospec(self.test_class)
    test_api.test_fxn.side_effect = [
        InternalServerError("testing"),
        InternalServerError("testing"),
        "true",
    ]
    allowed_exceptions = {
        InternalServerError: lambda exc: None,
    }
    deadline = time.time() + 30
    result_after_retries = _retry_on_aborted_exception(
        functools.partial(test_api.test_fxn),
        deadline=deadline,
        allowed_exceptions=allowed_exceptions,
    )

    self.assertEqual(test_api.test_fxn.call_count, 3)
    # Pin the exact value; assertTrue would accept any truthy result.
    self.assertEqual(result_after_retries, "true")
976+
885977

886978
class Test_metadata_with_leader_aware_routing(unittest.TestCase):
887979
def _call_fut(self, *args, **kw):

tests/unit/test_database.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,12 +1931,13 @@ def test_context_mgr_w_commit_stats_error(self):
19311931
return_commit_stats=True,
19321932
request_options=RequestOptions(),
19331933
)
1934-
api.commit.assert_called_once_with(
1934+
self.assertEqual(api.commit.call_count, 2)
1935+
api.commit.assert_any_call(
19351936
request=request,
19361937
metadata=[
19371938
("google-cloud-resource-prefix", database.name),
19381939
("x-goog-spanner-route-to-leader", "true"),
1939-
],
1940+
]
19401941
)
19411942

19421943
database.logger.info.assert_not_called()

0 commit comments

Comments
 (0)