diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index f1e5374f4b..a33ca8b0fb 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -728,17 +728,20 @@ def execute_partitioned_dml( _metadata_with_leader_aware_routing(self._route_to_leader_enabled) ) - nth_request = getattr(self, "_next_nth_request", 0) # Attempt will be incremented inside _restart_on_unavailable. - attempt = AtomicCounter(1) + begin_txn_nth_request = self._next_nth_request + begin_txn_attempt = AtomicCounter(1) + partial_nth_request = self._next_nth_request + partial_attempt = AtomicCounter(0) def execute_pdml(): with SessionCheckout(self._pool) as session: - all_metadata = self.metadata_with_request_id( - nth_request, attempt.value, metadata - ) txn = api.begin_transaction( - session=session.name, options=txn_options, metadata=all_metadata + session=session.name, + options=txn_options, + metadata=self.metadata_with_request_id( + begin_txn_nth_request, begin_txn_attempt.value, metadata + ), ) txn_selector = TransactionSelector(id=txn.id) @@ -751,18 +754,24 @@ def execute_pdml(): query_options=query_options, request_options=request_options, ) - method = functools.partial( - api.execute_streaming_sql, - metadata=metadata, - ) + + def wrapped_method(*args, **kwargs): + partial_attempt.increment() + method = functools.partial( + api.execute_streaming_sql, + metadata=self.metadata_with_request_id( + partial_nth_request, partial_attempt.value, metadata + ), + ) + return method(*args, **kwargs) iterator = _restart_on_unavailable( - method=method, + method=wrapped_method, trace_name="CloudSpanner.ExecuteStreamingSql", request=request, transaction_selector=txn_selector, observability_options=self.observability_options, - attempt=attempt, + attempt=begin_txn_attempt, ) result_set = StreamedResultSet(iterator) diff --git a/google/cloud/spanner_v1/request_id_header.py b/google/cloud/spanner_v1/request_id_header.py index 6e4e0e3fd8..c2c0b5464e 100644 --- a/google/cloud/spanner_v1/request_id_header.py +++ b/google/cloud/spanner_v1/request_id_header.py @@ -38,5 +38,6 @@ def generate_rand_uint64(): def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]): req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}" - other_metadata.append((REQ_ID_HEADER_KEY, req_id)) - return other_metadata + all_metadata = other_metadata.copy() + all_metadata.append((REQ_ID_HEADER_KEY, req_id)) + return all_metadata diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 345f5ebf4c..faff9e8f6f 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -30,6 +30,7 @@ from google.cloud.spanner_v1 import ExecuteSqlRequest from google.cloud.spanner_v1 import TransactionSelector from google.cloud.spanner_v1 import TransactionOptions +from google.cloud.spanner_v1._helpers import AtomicCounter from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call @@ -197,21 +198,29 @@ def rollback(self): database._route_to_leader_enabled ) ) - all_metadata = database.metadata_with_request_id(database._next_nth_request, 1, metadata) + observability_options = getattr(database, "observability_options", None) with trace_call( f"CloudSpanner.{type(self).__name__}.rollback", self._session, observability_options=observability_options, ): - method = functools.partial( - api.rollback, - session=self._session.name, - transaction_id=self._transaction_id, - metadata=all_metadata, - ) + attempt = AtomicCounter(0) + nth_request = database._next_nth_request + + def wrapped_method(*args, **kwargs): + attempt.increment() + method = functools.partial( + api.rollback, + session=self._session.name, + transaction_id=self._transaction_id, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + ) + _retry( - method, + wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, ) self.rolled_back = True @@ -286,11 +295,19 @@ def commit( ) as span: add_span_event(span, "Starting Commit") - method = functools.partial( - api.commit, - request=request, - metadata=database.metadata_with_request_id(database._next_nth_request, 1, metadata), - ) + attempt = AtomicCounter(0) + nth_request = database._next_nth_request + + def wrapped_method(*args, **kwargs): + attempt.increment() + method = functools.partial( + api.commit, + request=request, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + ) + return method(*args, **kwargs) def beforeNextRetry(nthRetry, delayInSeconds): add_span_event( @@ -300,7 +317,7 @@ def beforeNextRetry(nthRetry, delayInSeconds): ) response = _retry( - method, + wrapped_method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, beforeNextRetry=beforeNextRetry, ) @@ -434,19 +451,27 @@ def execute_update( request_options=request_options, ) - method = functools.partial( - api.execute_sql, - request=request, - metadata=metadata, - retry=retry, - timeout=timeout, - ) + nth_request = database._next_nth_request + attempt = AtomicCounter(0) + + def wrapped_method(*args, **kwargs): + attempt.increment() + method = functools.partial( + api.execute_sql, + request=request, + metadata=database.metadata_with_request_id( + nth_request, attempt.value, metadata + ), + retry=retry, + timeout=timeout, + ) + return method(*args, **kwargs) if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: response = self._execute_request( - method, + wrapped_method, request, f"CloudSpanner.{type(self).__name__}.execute_update", self._session, @@ -463,7 +488,7 @@ def execute_update( self._transaction_id = response.metadata.transaction.id else: response = self._execute_request( - method, + wrapped_method, request, f"CloudSpanner.{type(self).__name__}.execute_update", self._session, diff --git a/tests/unit/test_request_id_header.py b/tests/unit/test_request_id_header.py index c34afae1d9..35c7de7410 100644 --- a/tests/unit/test_request_id_header.py +++ b/tests/unit/test_request_id_header.py @@ -22,6 +22,7 @@ from google.cloud.spanner_v1.testing.interceptors import XGoogRequestIDHeaderInterceptor from google.cloud.spanner_v1 import ( BatchCreateSessionsRequest, + BeginTransactionRequest, ExecuteSqlRequest, ) from google.api_core.exceptions import Aborted @@ -195,7 +196,7 @@ def select1(): ] assert got_stream_segments == want_stream_segments - def test_retries_on_abort(self): + def test_database_run_in_transaction_retries_on_abort(self): counters = dict(aborted=0) want_failed_attempts = 2 @@ -217,10 +218,157 @@ def select_in_txn(txn): self.database.run_in_transaction(select_in_txn) + def test_database_execute_partitioned_dml_request_id(self): + add_select1_result() + if not getattr(self.database, "_interceptors", None): + self.database._interceptors = MockServerTestBase._interceptors + _ = self.database.execute_partitioned_dml("select 1") + + requests = self.spanner_service.requests + self.assertEqual(3, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], BeginTransactionRequest)) + self.assertTrue(isinstance(requests[2], ExecuteSqlRequest)) + + # Now ensure monotonicity of the received request-id segments. + got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() + want_unary_segments = [ + ( + "/google.spanner.v1.Spanner/BatchCreateSessions", + (1, REQ_RAND_PROCESS_ID, 1, 1, 1, 1), + ), + ( + "/google.spanner.v1.Spanner/BeginTransaction", + (1, REQ_RAND_PROCESS_ID, 1, 1, 2, 1), + ), + ] + want_stream_segments = [ + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, 1, 1, 3, 1), + ) + ] + + assert got_unary_segments == want_unary_segments + assert got_stream_segments == want_stream_segments + + def test_snapshot_read(self): + add_select1_result() + if not getattr(self.database, "_interceptors", None): + self.database._interceptors = MockServerTestBase._interceptors + with self.database.snapshot() as snapshot: + results = snapshot.read("select 1") + result_list = [] + for row in results: + result_list.append(row) + self.assertEqual(1, row[0]) + self.assertEqual(1, len(result_list)) + + requests = self.spanner_service.requests + self.assertEqual(2, len(requests), msg=requests) + self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest)) + self.assertTrue(isinstance(requests[1], ExecuteSqlRequest)) + + requests = self.spanner_service.requests + self.assertEqual(n * 2, len(requests), msg=requests) + + client_id = self.database._nth_client_id + channel_id = self.database._channel_id + got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers() + + want_unary_segments = [ + ( + "/google.spanner.v1.Spanner/BatchCreateSessions", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1), + ), + ( + "/google.spanner.v1.Spanner/GetSession", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1), + ), + ] + assert got_unary_segments == want_unary_segments + + want_stream_segments = [ + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1), + ), + ( + "/google.spanner.v1.Spanner/ExecuteStreamingSql", + (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1), + ), + ] + assert got_stream_segments == want_stream_segments + def canonicalize_request_id_headers(self): src = self.database._x_goog_request_id_interceptor return src._stream_req_segments, src._unary_req_segments + class FauxCall: def __init__(self, code, details="FauxCall"): self._code = code