Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ system_tests/local_test_setup
# Make sure a generated file isn't accidentally committed.
pylintrc
pylintrc.test


# Ignore coverage files
.coverage*
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/transaction_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
from google.cloud.spanner_dbapi.exceptions import RetryAborted
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1._helpers import _get_retry_delay

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection, Cursor
Expand Down
75 changes: 75 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

from google.api_core import datetime_helpers
from google.api_core.exceptions import Aborted
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
Expand Down Expand Up @@ -460,6 +464,23 @@ def _metadata_with_prefix(prefix, **kw):
return [("google-cloud-resource-prefix", prefix)]


def _retry_on_aborted_exception(
func,
deadline,
):
"""
Handles retry logic for Aborted exceptions, considering the deadline.
"""
attempts = 0
while True:
try:
attempts += 1
return func()
except Aborted as exc:
_delay_until_retry(exc, deadline=deadline, attempts=attempts)
continue


def _retry(
func,
retry_count=5,
Expand Down Expand Up @@ -529,6 +550,60 @@ def _metadata_with_leader_aware_routing(value, **kw):
return ("x-goog-spanner-route-to-leader", str(value).lower())


def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.

Detect retryable abort, and impose server-supplied delay.

:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction

:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""

cause = exc.errors[0]
now = time.time()
if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.

:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction

:rtype: float
:returns: seconds to wait before retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
if hasattr(cause, "trailing_metadata"):
metadata = dict(cause.trailing_metadata())
else:
metadata = {}
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()


class AtomicCounter:
def __init__(self, start_value=0):
self.__lock = threading.Lock()
Expand Down
16 changes: 13 additions & 3 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError
import time

DEFAULT_RETRY_TIMEOUT_SECS = 30


class _BatchBase(_SessionWrapper):
Expand Down Expand Up @@ -162,6 +166,7 @@ def commit(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kwargs,
):
"""Commit mutations to the database.

Expand Down Expand Up @@ -227,9 +232,12 @@ def commit(
request=request,
metadata=metadata,
)
response = _retry(
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
response = _retry_on_aborted_exception(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
deadline=deadline,
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
Expand Down Expand Up @@ -348,7 +356,9 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
},
)
self.committed = True
return response
Expand Down
10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ def batch(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
"""Return an object which wraps a batch.

Expand Down Expand Up @@ -805,7 +806,11 @@ def batch(
:returns: new wrapper
"""
return BatchCheckout(
self, request_options, max_commit_delay, exclude_txn_from_change_streams
self,
request_options,
max_commit_delay,
exclude_txn_from_change_streams,
**kw,
)

def mutation_groups(self):
Expand Down Expand Up @@ -1166,6 +1171,7 @@ def __init__(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
self._database = database
self._session = self._batch = None
Expand All @@ -1177,6 +1183,7 @@ def __init__(
self._request_options = request_options
self._max_commit_delay = max_commit_delay
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams
self._kw = kw

def __enter__(self):
"""Begin ``with`` block."""
Expand All @@ -1197,6 +1204,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
exclude_txn_from_change_streams=self._exclude_txn_from_change_streams,
**self._kw,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
Expand Down
58 changes: 2 additions & 56 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Wrapper for Cloud Spanner Session objects."""

from functools import total_ordering
import random
import time
from datetime import datetime

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import GoogleAPICallError
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1 import method
from google.rpc.error_details_pb2 import RetryInfo
from google.cloud.spanner_v1._helpers import _delay_until_retry
from google.cloud.spanner_v1._helpers import _get_retry_delay

from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import CreateSessionRequest
Expand Down Expand Up @@ -554,57 +554,3 @@ def run_in_transaction(self, func, *args, **kw):
extra={"commit_stats": txn.commit_stats},
)
return return_value


# Rational: this function factors out complex shared deadline / retry
# handling from two `except:` clauses.
def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.

Detect retryable abort, and impose server-supplied delay.

:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction

:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
cause = exc.errors[0]

now = time.time()

if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.

:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction

:rtype: float
:returns: seconds to wait before retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
metadata = dict(cause.trailing_metadata())
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()
17 changes: 13 additions & 4 deletions google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,19 @@ def __create_transaction(
def Commit(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
del self.transactions[request.transaction_id]
if not request.transaction_id == b"":
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
tx_id = request.transaction_id
elif not request.single_use_transaction == TransactionOptions():
tx = self.__create_transaction(
request.session, request.single_use_transaction
)
tx_id = tx.id
else:
raise ValueError("Unsupported transaction type")
del self.transactions[tx_id]
return commit.CommitResponse()

def Rollback(self, request, context):
Expand Down
24 changes: 24 additions & 0 deletions tests/mockserver_tests/test_aborted_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,30 @@ def test_run_in_transaction_batch_dml_aborted(self):
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
self.assertTrue(isinstance(requests[3], CommitRequest))

def test_batch_commit_aborted(self):
# Add an Aborted error for the Commit method on the mock server.
add_error(SpannerServicer.Commit.__name__, aborted_status())
with self.database.batch() as batch:
batch.insert(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(1, "Marc", "Richards"),
(2, "Catalina", "Smith"),
(3, "Alice", "Trentor"),
(4, "Lea", "Martin"),
(5, "David", "Lomond"),
],
)

# Verify that the transaction was retried.
requests = self.spanner_service.requests
self.assertEqual(3, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], CommitRequest))
# The transaction is aborted and retried.
self.assertTrue(isinstance(requests[2], CommitRequest))


def _insert_mutations(transaction: Transaction):
transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])
Expand Down
Loading
Loading