More plumbing for Transaction and Database
odeke-em committed Dec 18, 2024
1 parent 4a37f4c commit 529333a
Showing 4 changed files with 221 additions and 38 deletions.
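This change threads a per-RPC x-goog-spanner-request-id header through Database.execute_partitioned_dml and the Transaction rollback/commit/execute_update paths, rebuilding the header on every retry attempt. A minimal sketch of the header layout, following with_request_id() in request_id_header.py below (the constant values here are placeholders, not the library's actual values):

REQ_ID_VERSION = 1                      # assumed format version
REQ_RAND_PROCESS_ID = 1234567890        # per-process random id (placeholder)
REQ_ID_HEADER_KEY = "x-goog-spanner-request-id"  # assumed header name

def build_request_id(client_id, channel_id, nth_request, attempt):
    # version.process.client.channel.nth_request.attempt
    return f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
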
33 changes: 21 additions & 12 deletions google/cloud/spanner_v1/database.py
@@ -728,17 +728,20 @@ def execute_partitioned_dml(
_metadata_with_leader_aware_routing(self._route_to_leader_enabled)
)

nth_request = getattr(self, "_next_nth_request", 0)
# Attempt will be incremented inside _restart_on_unavailable.
attempt = AtomicCounter(1)
begin_txn_nth_request = self._next_nth_request
begin_txn_attempt = AtomicCounter(1)
partial_nth_request = self._next_nth_request
partial_attempt = AtomicCounter(0)

def execute_pdml():
with SessionCheckout(self._pool) as session:
all_metadata = self.metadata_with_request_id(
nth_request, attempt.value, metadata
)
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=all_metadata
session=session.name,
options=txn_options,
metadata=self.metadata_with_request_id(
begin_txn_nth_request, begin_txn_attempt.value, metadata
),
)

txn_selector = TransactionSelector(id=txn.id)
@@ -751,18 +754,24 @@ def execute_pdml():
query_options=query_options,
request_options=request_options,
)
method = functools.partial(
api.execute_streaming_sql,
metadata=metadata,
)

def wrapped_method(*args, **kwargs):
partial_attempt.increment()
method = functools.partial(
api.execute_streaming_sql,
metadata=self.metadata_with_request_id(
partial_nth_request, partial_attempt.value, metadata
),
)
return method(*args, **kwargs)

iterator = _restart_on_unavailable(
method=method,
method=wrapped_method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
attempt=attempt,
attempt=begin_txn_attempt,
)

result_set = StreamedResultSet(iterator)
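In execute_partitioned_dml, the BeginTransaction call and the ExecuteStreamingSql stream now reserve separate nth_request values, and the streaming call rebuilds its metadata inside wrapped_method so each attempt carries a fresh attempt number. A self-contained sketch of that pattern (AtomicCounter and metadata_with_request_id are stand-ins for the library helpers, with assumed behavior):

import functools

class AtomicCounter:
    # stand-in for google.cloud.spanner_v1._helpers.AtomicCounter (assumed API)
    def __init__(self, start=0):
        self.value = start
    def increment(self):
        self.value += 1
        return self.value

def metadata_with_request_id(nth_request, attempt, extra):
    # stand-in: append a request-id tuple without mutating the caller's list
    return list(extra) + [("x-goog-spanner-request-id", f"1.0.1.1.{nth_request}.{attempt}")]

def make_wrapped(rpc, nth_request, base_metadata):
    attempt = AtomicCounter(0)
    def wrapped_method(*args, **kwargs):
        attempt.increment()  # every invocation, including retries, is a new attempt
        method = functools.partial(
            rpc,
            metadata=metadata_with_request_id(nth_request, attempt.value, base_metadata),
        )
        return method(*args, **kwargs)
    return wrapped_method
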
5 changes: 3 additions & 2 deletions google/cloud/spanner_v1/request_id_header.py
@@ -38,5 +38,6 @@ def generate_rand_uint64():

def with_request_id(client_id, channel_id, nth_request, attempt, other_metadata=[]):
req_id = f"{REQ_ID_VERSION}.{REQ_RAND_PROCESS_ID}.{client_id}.{channel_id}.{nth_request}.{attempt}"
other_metadata.append((REQ_ID_HEADER_KEY, req_id))
return other_metadata
all_metadata = other_metadata.copy()
all_metadata.append((REQ_ID_HEADER_KEY, req_id))
return all_metadata
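
Copying other_metadata before appending avoids both mutating the caller's list and the classic mutable-default-argument leak; a small illustration of the failure mode the copy prevents:

def with_request_id_buggy(req_id, other_metadata=[]):
    # pre-fix behavior: appends into the shared default list
    other_metadata.append(("x-goog-spanner-request-id", req_id))
    return other_metadata

first = with_request_id_buggy("1.1.1.1.1.1")
second = with_request_id_buggy("1.1.1.1.2.1")
assert len(second) == 2  # the first call's header leaked into the second
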
71 changes: 48 additions & 23 deletions google/cloud/spanner_v1/transaction.py
@@ -30,6 +30,7 @@
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import TransactionSelector
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud.spanner_v1._helpers import AtomicCounter
from google.cloud.spanner_v1.snapshot import _SnapshotBase
from google.cloud.spanner_v1.batch import _BatchBase
from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call
@@ -197,21 +198,29 @@ def rollback(self):
database._route_to_leader_enabled
)
)
all_metadata = database.metadata_with_request_id(database._next_nth_request, 1, metadata)

observability_options = getattr(database, "observability_options", None)
with trace_call(
f"CloudSpanner.{type(self).__name__}.rollback",
self._session,
observability_options=observability_options,
):
method = functools.partial(
api.rollback,
session=self._session.name,
transaction_id=self._transaction_id,
metadata=all_metadata,
)
attempt = AtomicCounter(0)
nth_request = database._next_nth_request

def wrapped_method(*args, **kwargs):
attempt.increment()
method = functools.partial(
api.rollback,
session=self._session.name,
transaction_id=self._transaction_id,
metadata=database.metadata_with_request_id(
nth_request, attempt.value, metadata
),
                )
                return method(*args, **kwargs)

_retry(
method,
wrapped_method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.rolled_back = True
@@ -286,11 +295,19 @@ def commit(
) as span:
add_span_event(span, "Starting Commit")

method = functools.partial(
api.commit,
request=request,
metadata=database.metadata_with_request_id(database._next_nth_request, 1, metadata),
)
attempt = AtomicCounter(0)
nth_request = database._next_nth_request

def wrapped_method(*args, **kwargs):
attempt.increment()
method = functools.partial(
api.commit,
request=request,
metadata=database.metadata_with_request_id(
nth_request, attempt.value, metadata
),
)
return method(*args, **kwargs)

def beforeNextRetry(nthRetry, delayInSeconds):
add_span_event(
@@ -300,7 +317,7 @@ def beforeNextRetry(nthRetry, delayInSeconds):
)

response = _retry(
method,
wrapped_method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
beforeNextRetry=beforeNextRetry,
)
@@ -434,19 +451,27 @@ def execute_update(
request_options=request_options,
)

method = functools.partial(
api.execute_sql,
request=request,
metadata=metadata,
retry=retry,
timeout=timeout,
)
nth_request = database._next_nth_request
attempt = AtomicCounter(0)

def wrapped_method(*args, **kwargs):
attempt.increment()
method = functools.partial(
api.execute_sql,
request=request,
metadata=database.metadata_with_request_id(
nth_request, attempt.value, metadata
),
retry=retry,
timeout=timeout,
)
return method(*args, **kwargs)

if self._transaction_id is None:
# lock is added to handle the inline begin for first rpc
with self._lock:
response = self._execute_request(
method,
wrapped_method,
request,
f"CloudSpanner.{type(self).__name__}.execute_update",
self._session,
@@ -463,7 +488,7 @@ def execute_update(
self._transaction_id = response.metadata.transaction.id
else:
response = self._execute_request(
method,
wrapped_method,
request,
f"CloudSpanner.{type(self).__name__}.execute_update",
self._session,
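rollback, commit and execute_update now hand _retry a closure instead of a pre-built functools.partial, so every retry re-enters wrapped_method, bumps the attempt counter, and stamps the header with the new attempt number. A simplified stand-in for _retry (the real helper also supports delays, per-exception checks and a beforeNextRetry hook) shows why this matters:

def _retry_sketch(func, attempts=3, allowed_exceptions=()):
    # calls func up to `attempts` times; each call is a fresh attempt,
    # so a closure like wrapped_method rebuilds its metadata every time
    last_exc = None
    for _ in range(attempts):
        try:
            return func()
        except allowed_exceptions as exc:
            last_exc = exc
    raise last_exc
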
150 changes: 149 additions & 1 deletion tests/unit/test_request_id_header.py
@@ -22,6 +22,7 @@
from google.cloud.spanner_v1.testing.interceptors import XGoogRequestIDHeaderInterceptor
from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
BeginTransactionRequest,
ExecuteSqlRequest,
)
from google.api_core.exceptions import Aborted
@@ -195,7 +196,7 @@ def select1():
]
assert got_stream_segments == want_stream_segments

def test_retries_on_abort(self):
def test_database_run_in_transaction_retries_on_abort(self):
counters = dict(aborted=0)
want_failed_attempts = 2

@@ -217,10 +218,157 @@ def select_in_txn(txn):

self.database.run_in_transaction(select_in_txn)

def test_database_execute_partitioned_dml_request_id(self):
add_select1_result()
if not getattr(self.database, "_interceptors", None):
self.database._interceptors = MockServerTestBase._interceptors
_ = self.database.execute_partitioned_dml("select 1")

requests = self.spanner_service.requests
self.assertEqual(3, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))

# Now ensure monotonicity of the received request-id segments.
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()
want_unary_segments = [
(
"/google.spanner.v1.Spanner/BatchCreateSessions",
(1, REQ_RAND_PROCESS_ID, 1, 1, 1, 1),
),
(
"/google.spanner.v1.Spanner/BeginTransaction",
(1, REQ_RAND_PROCESS_ID, 1, 1, 2, 1),
),
]
want_stream_segments = [
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, 1, 1, 3, 1),
)
]

assert got_unary_segments == want_unary_segments
assert got_stream_segments == want_stream_segments

def test_snapshot_read(self):
add_select1_result()
if not getattr(self.database, "_interceptors", None):
self.database._interceptors = MockServerTestBase._interceptors
with self.database.snapshot() as snapshot:
            results = snapshot.execute_sql("select 1")
result_list = []
for row in results:
result_list.append(row)
self.assertEqual(1, row[0])
self.assertEqual(1, len(result_list))

requests = self.spanner_service.requests
self.assertEqual(2, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))

requests = self.spanner_service.requests
self.assertEqual(n * 2, len(requests), msg=requests)

client_id = self.database._nth_client_id
channel_id = self.database._channel_id
got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()

want_unary_segments = [
(
"/google.spanner.v1.Spanner/BatchCreateSessions",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 3, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 5, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 7, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 9, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 11, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 13, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 15, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 17, 1),
),
(
"/google.spanner.v1.Spanner/GetSession",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 19, 1),
),
]
assert got_unary_segments == want_unary_segments

want_stream_segments = [
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 2, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 4, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 6, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 8, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 10, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 12, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 14, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 16, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 18, 1),
),
(
"/google.spanner.v1.Spanner/ExecuteStreamingSql",
(1, REQ_RAND_PROCESS_ID, client_id, channel_id, 20, 1),
),
]
assert got_stream_segments == want_stream_segments

def canonicalize_request_id_headers(self):
src = self.database._x_goog_request_id_interceptor
return src._stream_req_segments, src._unary_req_segments


class FauxCall:
def __init__(self, code, details="FauxCall"):
self._code = code
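The assertions above compare canonicalized tuples rather than raw header strings; a sketch of the mapping they assume (assuming every segment, including the process id, is numeric):

def parse_request_id(value):
    # "1.<proc>.<client>.<channel>.<nth>.<attempt>" -> (1, proc, client, channel, nth, attempt)
    return tuple(int(part) for part in value.split("."))

assert parse_request_id("1.5678.1.1.3.1") == (1, 5678, 1, 1, 3, 1)
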
