diff --git a/google/cloud/spanner_v1/testing/mock_spanner.py b/google/cloud/spanner_v1/testing/mock_spanner.py index 1f37ff2a03..e093fc7387 100644 --- a/google/cloud/spanner_v1/testing/mock_spanner.py +++ b/google/cloud/spanner_v1/testing/mock_spanner.py @@ -187,7 +187,7 @@ def __create_transaction( def Commit(self, request, context): self._requests.append(request) self.mock_spanner.pop_error(context) - tx = self.transactions[request.transaction_id] + tx = self.transactions.get(request.transaction_id, None) if tx is None: raise ValueError(f"Transaction not found: {request.transaction_id}") del self.transactions[request.transaction_id] diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index a8aef7f470..345f5ebf4c 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -197,6 +197,7 @@ def rollback(self): database._route_to_leader_enabled ) ) + all_metadata = database.metadata_with_request_id(database._next_nth_request, 1, metadata) observability_options = getattr(database, "observability_options", None) with trace_call( f"CloudSpanner.{type(self).__name__}.rollback", @@ -207,7 +208,7 @@ def rollback(self): api.rollback, session=self._session.name, transaction_id=self._transaction_id, - metadata=metadata, + metadata=all_metadata, ) _retry( method, @@ -288,7 +289,7 @@ def commit( method = functools.partial( api.commit, request=request, - metadata=metadata, + metadata=database.metadata_with_request_id(database._next_nth_request, 1, metadata), ) def beforeNextRetry(nthRetry, delayInSeconds): diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py index df282f6356..c34afae1d9 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/unit/test_request_id_header.py @@ -24,6 +24,8 @@ BatchCreateSessionsRequest, ExecuteSqlRequest, ) +from google.api_core.exceptions import Aborted +from google.rpc import code_pb2 from google.cloud.spanner_v1.request_id_header import REQ_RAND_PROCESS_ID @@ -193,6 +195,45 @@ def select1(): ] assert got_stream_segments == want_stream_segments + def test_retries_on_abort(self): + counters = dict(aborted=0) + want_failed_attempts = 2 + + def select_in_txn(txn): + results = txn.execute_sql("select 1") + for row in results: + _ = row + + if counters["aborted"] < want_failed_attempts: + counters["aborted"] += 1 + raise Aborted( + "Thrown from ClientInterceptor for testing", + errors=[FauxCall(code_pb2.ABORTED)], + ) + + add_select1_result() + if not getattr(self.database, "_interceptors", None): + self.database._interceptors = MockServerTestBase._interceptors + + self.database.run_in_transaction(select_in_txn) + def canonicalize_request_id_headers(self): src = self.database._x_goog_request_id_interceptor return src._stream_req_segments, src._unary_req_segments + +class FauxCall: + def __init__(self, code, details="FauxCall"): + self._code = code + self._details = details + + def initial_metadata(self): + return {} + + def trailing_metadata(self): + return {} + + def code(self): + return self._code + + def details(self): + return self._details