Skip to content

Commit c4362d0

Browse files
committed
fix: update retry strategy for mutation calls to handle aborted transactions
1 parent a6811af commit c4362d0

28 files changed

+1209
-266
lines changed
Binary file not shown.

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,8 @@ system_tests/local_test_setup
6262
# Make sure a generated file isn't accidentally committed.
6363
pylintrc
6464
pylintrc.test
65+
66+
# Ignore coverage files
67+
.coverage*
68+
# Ignore the myenv directory
69+
myenv/

google/cloud/spanner_dbapi/transaction_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
2222
from google.cloud.spanner_dbapi.exceptions import RetryAborted
23-
from google.cloud.spanner_v1.session import _get_retry_delay
23+
from google.cloud.spanner_v1._helpers import _get_retry_delay
2424

2525
if TYPE_CHECKING:
2626
from google.cloud.spanner_dbapi import Connection, Cursor

google/cloud/spanner_v1/_helpers.py

Lines changed: 118 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,23 @@
1919
import math
2020
import time
2121
import base64
22+
import threading
2223

2324
from google.protobuf.struct_pb2 import ListValue
2425
from google.protobuf.struct_pb2 import Value
2526
from google.protobuf.message import Message
2627
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper
2728

2829
from google.api_core import datetime_helpers
30+
from google.api_core.exceptions import Aborted
2931
from google.cloud._helpers import _date_from_iso8601_date
3032
from google.cloud.spanner_v1 import TypeCode
3133
from google.cloud.spanner_v1 import ExecuteSqlRequest
3234
from google.cloud.spanner_v1 import JsonObject
35+
from google.cloud.spanner_v1.request_id_header import with_request_id
36+
from google.rpc.error_details_pb2 import RetryInfo
37+
38+
import random
3339

3440
# Validation error messages
3541
NUMERIC_MAX_SCALE_ERR_MSG = (
@@ -464,13 +470,19 @@ def _retry(
464470
delay=2,
465471
allowed_exceptions=None,
466472
beforeNextRetry=None,
473+
deadline=None,
467474
):
468475
"""
469-
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
476+
Retry a specified function with different logic based on the type of exception raised.
477+
478+
If the exception is of type google.api_core.exceptions.Aborted,
479+
apply an alternate retry strategy that relies on the provided deadline value instead of a fixed number of retries.
480+
For all other exceptions, retry the function up to a specified number of times.
470481
471482
Args:
472483
func: The function to be retried.
473484
retry_count: The maximum number of times to retry the function.
485+
deadline: This will be used in case of Aborted transactions.
474486
delay: The delay in seconds between retries.
475487
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
476488
Passing allowed_exceptions as None will lead to retrying for all exceptions.
@@ -479,13 +491,21 @@ def _retry(
479491
The result of the function if it is successful, or raises the last exception if all retries fail.
480492
"""
481493
retries = 0
482-
while retries <= retry_count:
494+
while True:
483495
if retries > 0 and beforeNextRetry:
484496
beforeNextRetry(retries, delay)
485497

486498
try:
487499
return func()
488500
except Exception as exc:
501+
if isinstance(exc, Aborted) and deadline is not None:
502+
if (
503+
allowed_exceptions is not None
504+
and allowed_exceptions.get(exc.__class__) is not None
505+
):
506+
retries += 1
507+
_delay_until_retry(exc, deadline=deadline, attempts=retries)
508+
continue
489509
if (
490510
allowed_exceptions is None or exc.__class__ in allowed_exceptions
491511
) and retries < retry_count:
@@ -525,3 +545,99 @@ def _metadata_with_leader_aware_routing(value, **kw):
525545
List[Tuple[str, str]]: RPC metadata with leader aware routing header
526546
"""
527547
return ("x-goog-spanner-route-to-leader", str(value).lower())
548+
549+
550+
def _delay_until_retry(exc, deadline, attempts):
551+
"""Helper for :meth:`Session.run_in_transaction`.
552+
553+
Detect retryable abort, and impose server-supplied delay.
554+
555+
:type exc: :class:`google.api_core.exceptions.Aborted`
556+
:param exc: exception for aborted transaction
557+
558+
:type deadline: float
559+
:param deadline: maximum timestamp to continue retrying the transaction.
560+
561+
:type attempts: int
562+
:param attempts: number of call retries
563+
"""
564+
565+
cause = exc.errors[0]
566+
now = time.time()
567+
if now >= deadline:
568+
raise
569+
570+
delay = _get_retry_delay(cause, attempts)
571+
if delay is not None:
572+
if now + delay > deadline:
573+
raise
574+
575+
time.sleep(delay)
576+
577+
578+
def _get_retry_delay(cause, attempts):
579+
"""Helper for :func:`_delay_until_retry`.
580+
581+
:type exc: :class:`grpc.Call`
582+
:param exc: exception for aborted transaction
583+
584+
:rtype: float
585+
:returns: seconds to wait before retrying the transaction.
586+
587+
:type attempts: int
588+
:param attempts: number of call retries
589+
"""
590+
if hasattr(cause, "trailing_metadata"):
591+
metadata = dict(cause.trailing_metadata())
592+
else:
593+
metadata = {}
594+
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
595+
if retry_info_pb is not None:
596+
retry_info = RetryInfo()
597+
retry_info.ParseFromString(retry_info_pb)
598+
nanos = retry_info.retry_delay.nanos
599+
return retry_info.retry_delay.seconds + nanos / 1.0e9
600+
601+
return 2**attempts + random.random()
602+
603+
604+
class AtomicCounter:
605+
def __init__(self, start_value=0):
606+
self.__lock = threading.Lock()
607+
self.__value = start_value
608+
609+
@property
610+
def value(self):
611+
with self.__lock:
612+
return self.__value
613+
614+
def increment(self, n=1):
615+
with self.__lock:
616+
self.__value += n
617+
return self.__value
618+
619+
def __iadd__(self, n):
620+
"""
621+
Defines the inplace += operator result.
622+
"""
623+
with self.__lock:
624+
self.__value += n
625+
return self
626+
627+
def __add__(self, n):
628+
"""
629+
Defines the result of invoking: value = AtomicCounter + addable
630+
"""
631+
with self.__lock:
632+
n += self.__value
633+
return n
634+
635+
def __radd__(self, n):
636+
"""
637+
Defines the result of invoking: value = addable + AtomicCounter
638+
"""
639+
return self.__add__(n)
640+
641+
642+
def _metadata_with_request_id(*args, **kwargs):
643+
return with_request_id(*args, **kwargs)

google/cloud/spanner_v1/_opentelemetry_tracing.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,11 @@ def get_tracer(tracer_provider=None):
5656

5757

5858
@contextmanager
59-
def trace_call(name, session, extra_attributes=None, observability_options=None):
59+
def trace_call(name, session=None, extra_attributes=None, observability_options=None):
6060
if session:
6161
session._last_use_time = datetime.now()
6262

63-
if not HAS_OPENTELEMETRY_INSTALLED or not session:
63+
if not (HAS_OPENTELEMETRY_INSTALLED and name):
6464
# Empty context manager. Users will have to check if the generated value is None or a span
6565
yield None
6666
return
@@ -72,20 +72,24 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
7272
# on by default.
7373
enable_extended_tracing = True
7474

75+
db_name = ""
76+
if session and getattr(session, "_database", None):
77+
db_name = session._database.name
78+
7579
if isinstance(observability_options, dict): # Avoid false positives with mock.Mock
7680
tracer_provider = observability_options.get("tracer_provider", None)
7781
enable_extended_tracing = observability_options.get(
7882
"enable_extended_tracing", enable_extended_tracing
7983
)
84+
db_name = observability_options.get("db_name", db_name)
8085

8186
tracer = get_tracer(tracer_provider)
8287

8388
# Set base attributes that we know for every trace created
84-
db = session._database
8589
attributes = {
8690
"db.type": "spanner",
8791
"db.url": SpannerClient.DEFAULT_ENDPOINT,
88-
"db.instance": "" if not db else db.name,
92+
"db.instance": db_name,
8993
"net.host.name": SpannerClient.DEFAULT_ENDPOINT,
9094
OTEL_SCOPE_NAME: TRACER_NAME,
9195
OTEL_SCOPE_VERSION: TRACER_VERSION,

google/cloud/spanner_v1/batch.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
from google.cloud.spanner_v1._helpers import _retry
3232
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
3333
from google.api_core.exceptions import InternalServerError
34+
from google.api_core.exceptions import Aborted
35+
import time
36+
37+
DEFAULT_RETRY_TIMEOUT_SECS = 30
3438

3539

3640
class _BatchBase(_SessionWrapper):
@@ -70,6 +74,8 @@ def insert(self, table, columns, values):
7074
:param values: Values to be modified.
7175
"""
7276
self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values)))
77+
# TODO: Decide if we should add a span event per mutation:
78+
# https://github.com/googleapis/python-spanner/issues/1269
7379

7480
def update(self, table, columns, values):
7581
"""Update one or more existing table rows.
@@ -84,6 +90,8 @@ def update(self, table, columns, values):
8490
:param values: Values to be modified.
8591
"""
8692
self._mutations.append(Mutation(update=_make_write_pb(table, columns, values)))
93+
# TODO: Decide if we should add a span event per mutation:
94+
# https://github.com/googleapis/python-spanner/issues/1269
8795

8896
def insert_or_update(self, table, columns, values):
8997
"""Insert/update one or more table rows.
@@ -100,6 +108,8 @@ def insert_or_update(self, table, columns, values):
100108
self._mutations.append(
101109
Mutation(insert_or_update=_make_write_pb(table, columns, values))
102110
)
111+
# TODO: Decide if we should add a span event per mutation:
112+
# https://github.com/googleapis/python-spanner/issues/1269
103113

104114
def replace(self, table, columns, values):
105115
"""Replace one or more table rows.
@@ -114,6 +124,8 @@ def replace(self, table, columns, values):
114124
:param values: Values to be modified.
115125
"""
116126
self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values)))
127+
# TODO: Decide if we should add a span event per mutation:
128+
# https://github.com/googleapis/python-spanner/issues/1269
117129

118130
def delete(self, table, keyset):
119131
"""Delete one or more table rows.
@@ -126,6 +138,8 @@ def delete(self, table, keyset):
126138
"""
127139
delete = Mutation.Delete(table=table, key_set=keyset._to_pb())
128140
self._mutations.append(Mutation(delete=delete))
141+
# TODO: Decide if we should add a span event per mutation:
142+
# https://github.com/googleapis/python-spanner/issues/1269
129143

130144

131145
class Batch(_BatchBase):
@@ -152,6 +166,7 @@ def commit(
152166
request_options=None,
153167
max_commit_delay=None,
154168
exclude_txn_from_change_streams=False,
169+
**kwargs,
155170
):
156171
"""Commit mutations to the database.
157172
@@ -207,7 +222,7 @@ def commit(
207222
)
208223
observability_options = getattr(database, "observability_options", None)
209224
with trace_call(
210-
"CloudSpanner.Commit",
225+
f"CloudSpanner.{type(self).__name__}.commit",
211226
self._session,
212227
trace_attributes,
213228
observability_options=observability_options,
@@ -217,9 +232,16 @@ def commit(
217232
request=request,
218233
metadata=metadata,
219234
)
235+
deadline = time.time() + kwargs.get(
236+
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
237+
)
220238
response = _retry(
221239
method,
222-
allowed_exceptions={InternalServerError: _check_rst_stream_error},
240+
allowed_exceptions={
241+
InternalServerError: _check_rst_stream_error,
242+
Aborted: no_op_handler,
243+
},
244+
deadline=deadline,
223245
)
224246
self.committed = response.commit_timestamp
225247
self.commit_stats = response.commit_stats
@@ -283,7 +305,9 @@ def group(self):
283305
self._mutation_groups.append(mutation_group)
284306
return MutationGroup(self._session, mutation_group.mutations)
285307

286-
def batch_write(self, request_options=None, exclude_txn_from_change_streams=False):
308+
def batch_write(
309+
self, request_options=None, exclude_txn_from_change_streams=False, **kwargs
310+
):
287311
"""Executes batch_write.
288312
289313
:type request_options:
@@ -336,9 +360,16 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
336360
request=request,
337361
metadata=metadata,
338362
)
363+
deadline = time.time() + kwargs.get(
364+
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
365+
)
339366
response = _retry(
340367
method,
341-
allowed_exceptions={InternalServerError: _check_rst_stream_error},
368+
allowed_exceptions={
369+
InternalServerError: _check_rst_stream_error,
370+
Aborted: no_op_handler,
371+
},
372+
deadline=deadline,
342373
)
343374
self.committed = True
344375
return response
@@ -362,3 +393,8 @@ def _make_write_pb(table, columns, values):
362393
return Mutation.Write(
363394
table=table, columns=columns, values=_make_list_value_pbs(values)
364395
)
396+
397+
398+
def no_op_handler(exc):
399+
# No-op (does nothing)
400+
pass

0 commit comments

Comments
 (0)