Skip to content

fix: update retry strategy for mutation calls to handle aborted transactions #1270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,8 @@ system_tests/local_test_setup
# Make sure a generated file isn't accidentally committed.
pylintrc
pylintrc.test

# Ignore coverage files
.coverage*
# Ignore the myenv directory
myenv/
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/transaction_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
from google.cloud.spanner_dbapi.exceptions import RetryAborted
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1._helpers import _get_retry_delay

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection, Cursor
Expand Down
120 changes: 118 additions & 2 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@
import math
import time
import base64
import threading

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
from google.protobuf.message import Message
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

from google.api_core import datetime_helpers
from google.api_core.exceptions import Aborted
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
Expand Down Expand Up @@ -464,13 +470,19 @@ def _retry(
delay=2,
allowed_exceptions=None,
beforeNextRetry=None,
deadline=None,
):
"""
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
Retry a specified function with different logic based on the type of exception raised.

If the exception is of type google.api_core.exceptions.Aborted,
apply an alternate retry strategy that relies on the provided deadline value instead of a fixed number of retries.
For all other exceptions, retry the function up to a specified number of times.

Args:
func: The function to be retried.
retry_count: The maximum number of times to retry the function.
deadline: This will be used in case of Aborted transactions.
delay: The delay in seconds between retries.
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
Passing allowed_exceptions as None will lead to retrying for all exceptions.
Expand All @@ -479,13 +491,21 @@ def _retry(
The result of the function if it is successful, or raises the last exception if all retries fail.
"""
retries = 0
while retries <= retry_count:
while True:
if retries > 0 and beforeNextRetry:
beforeNextRetry(retries, delay)

try:
return func()
except Exception as exc:
if isinstance(exc, Aborted) and deadline is not None:
if (
allowed_exceptions is not None
and allowed_exceptions.get(exc.__class__) is not None
):
retries += 1
_delay_until_retry(exc, deadline=deadline, attempts=retries)
continue
if (
allowed_exceptions is None or exc.__class__ in allowed_exceptions
) and retries < retry_count:
Expand Down Expand Up @@ -525,3 +545,99 @@ def _metadata_with_leader_aware_routing(value, **kw):
List[Tuple[str, str]]: RPC metadata with leader aware routing header
"""
return ("x-goog-spanner-route-to-leader", str(value).lower())


def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.

Detect retryable abort, and impose server-supplied delay.

:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction

:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""

cause = exc.errors[0]
now = time.time()
if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.

:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction

:rtype: float
:returns: seconds to wait before retrying the transaction.

:type attempts: int
:param attempts: number of call retries
"""
if hasattr(cause, "trailing_metadata"):
metadata = dict(cause.trailing_metadata())
else:
metadata = {}
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()


class AtomicCounter:
    """Thread-safe integer counter.

    All reads and updates are serialized on a single lock so the counter
    can be shared across threads. Supports ``counter += n`` (in place),
    ``counter + n`` and ``n + counter`` (both yield a plain number).
    """

    def __init__(self, start_value=0):
        # Guards every access to the underlying value.
        self.__lock = threading.Lock()
        self.__value = start_value

    @property
    def value(self):
        """Current counter value, read under the lock."""
        with self.__lock:
            return self.__value

    def increment(self, n=1):
        """Atomically add ``n`` (default 1) and return the new value."""
        with self.__lock:
            self.__value += n
            return self.__value

    def __iadd__(self, n):
        """Implement ``counter += n``; the counter itself is the result."""
        self.increment(n)
        return self

    def __add__(self, n):
        """Implement ``counter + addable``; returns a plain number."""
        with self.__lock:
            return self.__value + n

    def __radd__(self, n):
        """Implement ``addable + counter``; delegates to ``__add__``."""
        return self.__add__(n)


def _metadata_with_request_id(*args, **kwargs):
    """Forward all arguments unchanged to :func:`with_request_id`.

    Convenience pass-through so callers can build request-id metadata via
    this helpers module instead of importing ``request_id_header`` directly.
    """
    return with_request_id(*args, **kwargs)
12 changes: 8 additions & 4 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ def get_tracer(tracer_provider=None):


@contextmanager
def trace_call(name, session, extra_attributes=None, observability_options=None):
def trace_call(name, session=None, extra_attributes=None, observability_options=None):
if session:
session._last_use_time = datetime.now()

if not HAS_OPENTELEMETRY_INSTALLED or not session:
if not (HAS_OPENTELEMETRY_INSTALLED and name):
# Empty context manager. Users will have to check if the generated value is None or a span
yield None
return
Expand All @@ -72,20 +72,24 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
# on by default.
enable_extended_tracing = True

db_name = ""
if session and getattr(session, "_database", None):
db_name = session._database.name

if isinstance(observability_options, dict): # Avoid false positives with mock.Mock
tracer_provider = observability_options.get("tracer_provider", None)
enable_extended_tracing = observability_options.get(
"enable_extended_tracing", enable_extended_tracing
)
db_name = observability_options.get("db_name", db_name)

tracer = get_tracer(tracer_provider)

# Set base attributes that we know for every trace created
db = session._database
attributes = {
"db.type": "spanner",
"db.url": SpannerClient.DEFAULT_ENDPOINT,
"db.instance": "" if not db else db.name,
"db.instance": db_name,
"net.host.name": SpannerClient.DEFAULT_ENDPOINT,
OTEL_SCOPE_NAME: TRACER_NAME,
OTEL_SCOPE_VERSION: TRACER_VERSION,
Expand Down
44 changes: 40 additions & 4 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError
from google.api_core.exceptions import Aborted
import time

DEFAULT_RETRY_TIMEOUT_SECS = 30


class _BatchBase(_SessionWrapper):
Expand Down Expand Up @@ -70,6 +74,8 @@ def insert(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def update(self, table, columns, values):
"""Update one or more existing table rows.
Expand All @@ -84,6 +90,8 @@ def update(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(update=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def insert_or_update(self, table, columns, values):
"""Insert/update one or more table rows.
Expand All @@ -100,6 +108,8 @@ def insert_or_update(self, table, columns, values):
self._mutations.append(
Mutation(insert_or_update=_make_write_pb(table, columns, values))
)
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def replace(self, table, columns, values):
"""Replace one or more table rows.
Expand All @@ -114,6 +124,8 @@ def replace(self, table, columns, values):
:param values: Values to be modified.
"""
self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values)))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269

def delete(self, table, keyset):
"""Delete one or more table rows.
Expand All @@ -126,6 +138,8 @@ def delete(self, table, keyset):
"""
delete = Mutation.Delete(table=table, key_set=keyset._to_pb())
self._mutations.append(Mutation(delete=delete))
# TODO: Decide if we should add a span event per mutation:
# https://github.com/googleapis/python-spanner/issues/1269


class Batch(_BatchBase):
Expand All @@ -152,6 +166,7 @@ def commit(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kwargs,
):
"""Commit mutations to the database.

Expand Down Expand Up @@ -207,7 +222,7 @@ def commit(
)
observability_options = getattr(database, "observability_options", None)
with trace_call(
"CloudSpanner.Commit",
f"CloudSpanner.{type(self).__name__}.commit",
self._session,
trace_attributes,
observability_options=observability_options,
Expand All @@ -217,9 +232,16 @@ def commit(
request=request,
metadata=metadata,
)
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
Aborted: no_op_handler,
},
deadline=deadline,
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
Expand Down Expand Up @@ -283,7 +305,9 @@ def group(self):
self._mutation_groups.append(mutation_group)
return MutationGroup(self._session, mutation_group.mutations)

def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
def batch_write(
self, request_options=None, exclude_txn_from_change_streams=False, **kwargs
):
"""Executes batch_write.

:type request_options:
Expand Down Expand Up @@ -336,9 +360,16 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
request=request,
metadata=metadata,
)
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
Aborted: no_op_handler,
},
deadline=deadline,
)
self.committed = True
return response
Expand All @@ -362,3 +393,8 @@ def _make_write_pb(table, columns, values):
return Mutation.Write(
table=table, columns=columns, values=_make_list_value_pbs(values)
)


def no_op_handler(exc):
    """Intentionally do nothing with ``exc``.

    Registered as the ``allowed_exceptions`` callback for ``Aborted`` errors
    so the retry loop treats them as retryable without extra handling.
    """
    return None
Loading