Skip to content

Commit

Permalink
Base for tests with retries on abort
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Dec 18, 2024
1 parent 4ff0530 commit 4a37f4c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def __create_transaction(
def Commit(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
tx = self.transactions[request.transaction_id]
tx = self.transactions.get(request.transaction_id, None)
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
del self.transactions[request.transaction_id]
Expand Down
5 changes: 3 additions & 2 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def rollback(self):
database._route_to_leader_enabled
)
)
all_metadata = database.metadata_with_request_id(database._next_nth_request, 1, metadata)
observability_options = getattr(database, "observability_options", None)
with trace_call(
f"CloudSpanner.{type(self).__name__}.rollback",
Expand All @@ -207,7 +208,7 @@ def rollback(self):
api.rollback,
session=self._session.name,
transaction_id=self._transaction_id,
metadata=metadata,
metadata=all_metadata,
)
_retry(
method,
Expand Down Expand Up @@ -288,7 +289,7 @@ def commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
metadata=database.metadata_with_request_id(database._next_nth_request, 1, metadata),
)

def beforeNextRetry(nthRetry, delayInSeconds):
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_request_id_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
BatchCreateSessionsRequest,
ExecuteSqlRequest,
)
from google.api_core.exceptions import Aborted
from google.rpc import code_pb2
from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID


Expand Down Expand Up @@ -193,6 +195,45 @@ def select1():
]
assert got_stream_segments == want_stream_segments

def test_retries_on_abort(self):
counters = dict(aborted=0)
want_failed_attempts = 2

def select_in_txn(txn):
results = txn.execute_sql("select 1")
for row in results:
_ = row

if counters["aborted"] < want_failed_attempts:
counters["aborted"] += 1
raise Aborted(
"Thrown from ClientInterceptor for testing",
errors=[FauxCall(code_pb2.ABORTED)],
)

add_select1_result()
if not getattr(self.database, "_interceptors", None):
self.database._interceptors = MockServerTestBase._interceptors

self.database.run_in_transaction(select_in_txn)

def canonicalize_request_id_headers(self):
src = self.database._x_goog_request_id_interceptor
return src._stream_req_segments, src._unary_req_segments

class FauxCall:
def __init__(self, code, details="FauxCall"):
self._code = code
self._details = details

def initial_metadata(self):
return {}

def trailing_metadata(self):
return {}

def code(self):
return self._code

def details(self):
return self._details

0 comments on commit 4a37f4c

Please sign in to comment.