12 changes: 3 additions & 9 deletions google/cloud/spanner_v1/database_sessions_manager.py
@@ -86,16 +86,10 @@ def get_session(self, transaction_type: TransactionType) -> Session:
:returns: a session for the given transaction type.
"""

use_multiplexed = self._use_multiplexed(transaction_type)

# TODO multiplexed: enable for read/write transactions
if use_multiplexed and transaction_type == TransactionType.READ_WRITE:
raise NotImplementedError(
f"Multiplexed sessions are not yet supported for {transaction_type} transactions."
)

session = (
self._get_multiplexed_session() if use_multiplexed else self._pool.get()
self._get_multiplexed_session()
if self._use_multiplexed(transaction_type)
else self._pool.get()
)

add_span_event(
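Because the rendered diff above interleaves removed and added lines, the effective routing after this change is condensed below as an illustrative sketch (not the literal patch); `_use_multiplexed`, `_get_multiplexed_session`, and `_pool.get` are the existing private helpers referenced in the hunk.

```python
# Illustrative condensation of the new get_session() routing. The guard that
# raised NotImplementedError for READ_WRITE transactions is removed, so
# read/write work may now run on a multiplexed session when the manager allows it.
def get_session(self, transaction_type):
    if self._use_multiplexed(transaction_type):
        return self._get_multiplexed_session()
    return self._pool.get()
```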
84 changes: 37 additions & 47 deletions google/cloud/spanner_v1/session.py
@@ -74,9 +74,6 @@ def __init__(self, database, labels=None, database_role=None, is_multiplexed=Fal
self._database = database
self._session_id: Optional[str] = None

# TODO multiplexed - remove
self._transaction: Optional[Transaction] = None

if labels is None:
labels = {}

@@ -467,23 +464,18 @@ def batch(self):

return Batch(self)

def transaction(self):
def transaction(self) -> Transaction:
"""Create a transaction to perform a set of reads with shared staleness.

:rtype: :class:`~google.cloud.spanner_v1.transaction.Transaction`
:returns: a transaction bound to this session

:raises ValueError: if the session has not yet been created.
"""
if self._session_id is None:
raise ValueError("Session has not been created.")

# TODO multiplexed - remove
if self._transaction is not None:
self._transaction.rolled_back = True
self._transaction = None

txn = self._transaction = Transaction(self)
return txn
return Transaction(self)

def run_in_transaction(self, func, *args, **kw):
"""Perform a unit of work in a transaction, retrying on abort.
@@ -528,42 +520,43 @@ def run_in_transaction(self, func, *args, **kw):
)
isolation_level = kw.pop("isolation_level", None)

attempts = 0
database = self._database
log_commit_stats = database.log_commit_stats

observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.Session.run_in_transaction",
self,
observability_options=observability_options,
observability_options=getattr(database, "observability_options", None),
) as span, MetricsCapture():
attempts: int = 0

# If a transaction using a multiplexed session is retried after an aborted
# user operation, it should include the previous transaction ID in the
# transaction options used to begin the transaction. This allows the backend
# to recognize the transaction and increase the lock order for the new
# transaction that is created.
# See :attr:`~google.cloud.spanner_v1.types.TransactionOptions.ReadWrite.multiplexed_session_previous_transaction_id`
previous_transaction_id: Optional[bytes] = None

while True:
# TODO multiplexed - remove
if self._transaction is None:
txn = self.transaction()
txn.transaction_tag = transaction_tag
txn.exclude_txn_from_change_streams = (
exclude_txn_from_change_streams
txn = self.transaction()
txn.transaction_tag = transaction_tag
txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams
txn.isolation_level = isolation_level

if self.is_multiplexed:
txn._multiplexed_session_previous_transaction_id = (
previous_transaction_id
)
txn.isolation_level = isolation_level
else:
txn = self._transaction

span_attributes = dict()
attempts += 1
span_attributes = dict(attempt=attempts)

try:
attempts += 1
span_attributes["attempt"] = attempts
txn_id = getattr(txn, "_transaction_id", "") or ""
if txn_id:
span_attributes["transaction.id"] = txn_id

return_value = func(txn, *args, **kw)

# TODO multiplexed: store previous transaction ID.
except Aborted as exc:
# TODO multiplexed - remove
self._transaction = None

previous_transaction_id = txn._transaction_id
if span:
delay_seconds = _get_retry_delay(
exc.errors[0],
@@ -582,16 +575,15 @@ def run_in_transaction(self, func, *args, **kw):
exc, deadline, attempts, default_retry_delay=default_retry_delay
)
continue
except GoogleAPICallError:
# TODO multiplexed - remove
self._transaction = None

except GoogleAPICallError:
add_span_event(
span,
"User operation failed due to GoogleAPICallError, not retrying",
span_attributes,
)
raise

except Exception:
add_span_event(
span,
Expand All @@ -603,14 +595,13 @@ def run_in_transaction(self, func, *args, **kw):

try:
txn.commit(
return_commit_stats=self._database.log_commit_stats,
return_commit_stats=log_commit_stats,
request_options=commit_request_options,
max_commit_delay=max_commit_delay,
)
except Aborted as exc:
# TODO multiplexed - remove
self._transaction = None

except Aborted as exc:
previous_transaction_id = txn._transaction_id
if span:
delay_seconds = _get_retry_delay(
exc.errors[0],
@@ -621,26 +612,25 @@ def run_in_transaction(self, func, *args, **kw):
attributes.update(span_attributes)
add_span_event(
span,
"Transaction got aborted during commit, retrying afresh",
"Transaction was aborted during commit, retrying",
attributes,
)

_delay_until_retry(
exc, deadline, attempts, default_retry_delay=default_retry_delay
)
except GoogleAPICallError:
# TODO multiplexed - remove
self._transaction = None

except GoogleAPICallError:
add_span_event(
span,
"Transaction.commit failed due to GoogleAPICallError, not retrying",
span_attributes,
)
raise

else:
if self._database.log_commit_stats and txn.commit_stats:
self._database.logger.info(
if log_commit_stats and txn.commit_stats:
database.logger.info(
"CommitStats: {}".format(txn.commit_stats),
extra={"commit_stats": txn.commit_stats},
)
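The abort handling in `run_in_transaction` above can be hard to follow through the interleaved diff lines, so here is a condensed, hedged sketch of the new flow: on abort, the transaction ID of the failed attempt is remembered and, for multiplexed sessions, passed into the next attempt's transaction options. `run_user_work` is a hypothetical stand-in for the caller-supplied `func`; the session and transaction attribute names follow the diff.

```python
from google.api_core.exceptions import Aborted


def retry_sketch(session, run_user_work):
    """Condensed sketch of run_in_transaction's abort handling (illustrative only)."""
    previous_transaction_id = None
    while True:
        txn = session.transaction()
        if session.is_multiplexed:
            # Carrying the previous attempt's transaction ID into the new
            # transaction's options lets the backend recognize the retry and
            # increase its lock order.
            txn._multiplexed_session_previous_transaction_id = previous_transaction_id
        try:
            result = run_user_work(txn)
            txn.commit()
            return result
        except Aborted:
            # Both the user operation and the commit may abort; either way the
            # real code records the transaction ID, waits out the retry delay,
            # and loops. Deadlines, spans, and commit stats are omitted here.
            previous_transaction_id = txn._transaction_id
```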
91 changes: 49 additions & 42 deletions google/cloud/spanner_v1/snapshot.py
@@ -93,7 +93,7 @@ def _restart_on_unavailable(
item_buffer: List[PartialResultSet] = []

if transaction is not None:
transaction_selector = transaction._make_txn_selector()
transaction_selector = transaction._build_transaction_selector_pb()
elif transaction_selector is None:
raise InvalidArgument(
"Either transaction or transaction_selector should be set"
@@ -149,7 +149,7 @@ def _restart_on_unavailable(
) as span, MetricsCapture():
request.resume_token = resume_token
if transaction is not None:
transaction_selector = transaction._make_txn_selector()
transaction_selector = transaction._build_transaction_selector_pb()
request.transaction = transaction_selector
attempt += 1
iterator = method(
@@ -180,7 +180,7 @@ def _restart_on_unavailable(
) as span, MetricsCapture():
request.resume_token = resume_token
if transaction is not None:
transaction_selector = transaction._make_txn_selector()
transaction_selector = transaction._build_transaction_selector_pb()
attempt += 1
request.transaction = transaction_selector
iterator = method(
@@ -238,17 +238,6 @@ def __init__(self, session):
# threads, so we need to use a lock when updating the transaction.
self._lock: threading.Lock = threading.Lock()

def _make_txn_selector(self):
"""Helper for :meth:`read` / :meth:`execute_sql`.

Subclasses must override, returning an instance of
:class:`transaction_pb2.TransactionSelector`
appropriate for making ``read`` / ``execute_sql`` requests

:raises: NotImplementedError, always
"""
raise NotImplementedError

def begin(self) -> bytes:
"""Begins a transaction on the database.

@@ -732,7 +721,7 @@ def partition_read(
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
transaction = self._make_txn_selector()
transaction = self._build_transaction_selector_pb()
partition_options = PartitionOptions(
partition_size_bytes=partition_size_bytes, max_partitions=max_partitions
)
@@ -854,7 +843,7 @@ def partition_query(
metadata.append(
_metadata_with_leader_aware_routing(database._route_to_leader_enabled)
)
transaction = self._make_txn_selector()
transaction = self._build_transaction_selector_pb()
partition_options = PartitionOptions(
partition_size_bytes=partition_size_bytes, max_partitions=max_partitions
)
@@ -944,7 +933,7 @@ def _begin_transaction(self, mutation: Mutation = None) -> bytes:
def wrapped_method():
begin_transaction_request = BeginTransactionRequest(
session=session.name,
options=self._make_txn_selector().begin,
options=self._build_transaction_selector_pb().begin,
mutation_key=mutation,
)
begin_transaction_method = functools.partial(
@@ -983,6 +972,34 @@ def before_next_retry(nth_retry, delay_in_seconds):
self._update_for_transaction_pb(transaction_pb)
return self._transaction_id

def _build_transaction_options_pb(self) -> TransactionOptions:
"""Builds and returns the transaction options for this snapshot.

:rtype: :class:`transaction_pb2.TransactionOptions`
:returns: the transaction options for this snapshot.
"""
raise NotImplementedError

def _build_transaction_selector_pb(self) -> TransactionSelector:
"""Builds and returns a transaction selector for this snapshot.

:rtype: :class:`transaction_pb2.TransactionSelector`
:returns: a transaction selector for this snapshot.
"""

# Select a previously begun transaction.
if self._transaction_id is not None:
return TransactionSelector(id=self._transaction_id)

options = self._build_transaction_options_pb()

# Select a single-use transaction.
if not self._multi_use:
return TransactionSelector(single_use=options)

# Select a new, multi-use transaction.
return TransactionSelector(begin=options)

def _update_for_result_set_pb(
self, result_set_pb: Union[ResultSet, PartialResultSet]
) -> None:
@@ -1101,38 +1118,28 @@ def __init__(
self._multi_use = multi_use
self._transaction_id = transaction_id

# TODO multiplexed - refactor to base class
def _make_txn_selector(self):
"""Helper for :meth:`read`."""
if self._transaction_id is not None:
return TransactionSelector(id=self._transaction_id)
def _build_transaction_options_pb(self) -> TransactionOptions:
"""Builds and returns transaction options for this snapshot.

:rtype: :class:`transaction_pb2.TransactionOptions`
:returns: transaction options for this snapshot.
"""

read_only_pb_args = dict(return_read_timestamp=True)

if self._read_timestamp:
key = "read_timestamp"
value = self._read_timestamp
read_only_pb_args["read_timestamp"] = self._read_timestamp
elif self._min_read_timestamp:
key = "min_read_timestamp"
value = self._min_read_timestamp
read_only_pb_args["min_read_timestamp"] = self._min_read_timestamp
elif self._max_staleness:
key = "max_staleness"
value = self._max_staleness
read_only_pb_args["max_staleness"] = self._max_staleness
elif self._exact_staleness:
key = "exact_staleness"
value = self._exact_staleness
read_only_pb_args["exact_staleness"] = self._exact_staleness
else:
key = "strong"
value = True

options = TransactionOptions(
read_only=TransactionOptions.ReadOnly(
**{key: value, "return_read_timestamp": True}
)
)
read_only_pb_args["strong"] = True

if self._multi_use:
return TransactionSelector(begin=options)
else:
return TransactionSelector(single_use=options)
read_only_pb = TransactionOptions.ReadOnly(**read_only_pb_args)
return TransactionOptions(read_only=read_only_pb)

def _update_for_transaction_pb(self, transaction_pb: Transaction) -> None:
"""Updates the snapshot for the given transaction.
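As a usage note for the new selector helpers: for a read-only snapshot, `_build_transaction_selector_pb` produces one of three selector shapes depending on reuse and whether a transaction ID is already known. The sketch below shows those shapes for a strong read; field names come from the `spanner_v1` protos, and the top-level import path is the conventional re-export and is assumed here.

```python
from google.cloud.spanner_v1 import TransactionOptions, TransactionSelector

# Options for a strong read, as built by Snapshot._build_transaction_options_pb()
# when no timestamp bound is configured.
options = TransactionOptions(
    read_only=TransactionOptions.ReadOnly(strong=True, return_read_timestamp=True)
)

# Single-use snapshots send the options inline with each request...
single_use_selector = TransactionSelector(single_use=options)

# ...multi-use snapshots ask the backend to begin a transaction...
multi_use_selector = TransactionSelector(begin=options)

# ...and once a transaction ID is known, later requests select it directly.
resumed_selector = TransactionSelector(id=b"example-transaction-id")
```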