Wrap Snapshot methods with x-goog-request-id metadata injector
odeke-em committed Dec 14, 2024
1 parent 47da5e4 commit bcf868b
Showing 3 changed files with 112 additions and 44 deletions.
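
For context: every hunk below applies the same pattern to a Snapshot RPC call. A small closure bumps a per-call attempt counter and rebuilds the request metadata through _metadata_with_request_id before delegating to the underlying API method, so each attempt carries its own x-goog-request-id value. A condensed sketch of that pattern, using the helper and attribute names that appear in the diff (the helper bodies themselves are not part of this commit):

# Condensed sketch of the wrapping pattern used throughout this commit.
# AtomicCounter and _metadata_with_request_id are the helpers imported in
# snapshot.py below; only their call shapes are assumed here.
import functools

from google.cloud.spanner_v1._helpers import (
    AtomicCounter,
    _metadata_with_request_id,
)


def make_wrapped_call(api_method, request, metadata, database, session,
                      retry=None, timeout=None):
    nth_request = getattr(database, "_next_nth_request", 0)
    attempt = AtomicCounter(0)

    def wrapped_call(*args, **kwargs):
        # Each invocation (first try or retry) gets a fresh attempt number
        # and therefore fresh request-id metadata.
        attempt.increment()
        all_metadata = _metadata_with_request_id(
            getattr(database, "_nth_client_id", 0),
            getattr(session, "_channel_id", 0),
            nth_request,
            attempt.value,
            metadata,
        )
        call = functools.partial(
            api_method,
            request=request,
            metadata=all_metadata,
            retry=retry,
            timeout=timeout,
        )
        return call(*args, **kwargs)

    return wrapped_call
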
3 changes: 2 additions & 1 deletion google/cloud/spanner_v1/database.py
@@ -700,7 +700,8 @@ def execute_partitioned_dml(
             )

         nth_request = getattr(self, "_next_nth_request", 0)
-        attempt = AtomicCounter(1)  # It'll be incremented inside _restart_on_unavailable
+        # Attempt will be incremented inside _restart_on_unavailable.
+        attempt = AtomicCounter(1)

         def execute_pdml():
             with SessionCheckout(self._pool) as session:
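
The metadata is built inside the wrapper rather than once up front because the wrapped callable may be invoked more than once: _restart_on_unavailable and _retry call it again after retriable failures, and each re-invocation should advance the attempt number. A toy driver, purely for illustration and not the library's actual retry logic, that shows the interplay:

# Toy retry driver, for illustration only; it is not the library's
# _retry/_restart_on_unavailable implementation.
from google.api_core.exceptions import ServiceUnavailable


def call_with_restarts(wrapped_call, max_attempts=3):
    last_exc = None
    for _ in range(max_attempts):
        try:
            # Each call increments the shared AtomicCounter inside the
            # wrapper, so the injected request-id reflects the new attempt.
            return wrapped_call()
        except ServiceUnavailable as exc:
            last_exc = exc
    raise last_exc
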
149 changes: 108 additions & 41 deletions google/cloud/spanner_v1/snapshot.py
@@ -35,9 +35,11 @@
     _merge_query_options,
     _metadata_with_prefix,
     _metadata_with_leader_aware_routing,
+    _metadata_with_request_id,
     _retry,
     _check_rst_stream_error,
     _SessionWrapper,
+    AtomicCounter,
 )
 from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
 from google.cloud.spanner_v1.streamed import StreamedResultSet
@@ -320,13 +322,26 @@ def read(
             data_boost_enabled=data_boost_enabled,
             directed_read_options=directed_read_options,
         )
-        restart = functools.partial(
-            api.streaming_read,
-            request=request,
-            metadata=metadata,
-            retry=retry,
-            timeout=timeout,
-        )
+
+        nth_request = getattr(database, "_next_nth_request", 0)
+        attempt = AtomicCounter(0)
+
+        def wrapped_restart(*args, **kwargs):
+            attempt.increment()
+            channel_id = getattr(self._session, "_channel_id", 0)
+            client_id = getattr(database, "_nth_client_id", 0)
+            all_metadata = _metadata_with_request_id(
+                client_id, channel_id, nth_request, attempt.value, metadata
+            )
+
+            restart = functools.partial(
+                api.streaming_read,
+                request=request,
+                metadata=all_metadata,
+                retry=retry,
+                timeout=timeout,
+            )
+            return restart(*args, **kwargs)

         trace_attributes = {"table_id": table, "columns": columns}
         observability_options = getattr(database, "observability_options", None)
@@ -335,7 +350,7 @@ def read(
             # lock is added to handle the inline begin for first rpc
             with self._lock:
                 iterator = _restart_on_unavailable(
-                    restart,
+                    wrapped_restart,
                     request,
                     "CloudSpanner.ReadOnlyTransaction",
                     self._session,
@@ -357,7 +372,7 @@ def read(
                     )
         else:
             iterator = _restart_on_unavailable(
-                restart,
+                wrapped_restart,
                 request,
                 "CloudSpanner.ReadOnlyTransaction",
                 self._session,
@@ -536,13 +551,27 @@ def execute_sql(
             data_boost_enabled=data_boost_enabled,
             directed_read_options=directed_read_options,
         )
-        restart = functools.partial(
-            api.execute_streaming_sql,
-            request=request,
-            metadata=metadata,
-            retry=retry,
-            timeout=timeout,
-        )
+
+        nth_request = getattr(database, "_next_nth_request", 0)
+        attempt = AtomicCounter(0)
+
+        def wrapped_restart(*args, **kwargs):
+            attempt.increment()
+            channel_id = getattr(self._session, "_channel_id", 0)
+            client_id = getattr(database, "_nth_client_id", 0)
+            all_metadata = _metadata_with_request_id(
+                client_id, channel_id, nth_request, attempt.value, metadata
+            )
+
+            restart = functools.partial(
+                api.execute_streaming_sql,
+                request=request,
+                metadata=all_metadata,
+                retry=retry,
+                timeout=timeout,
+            )
+
+            return restart(*args, **kwargs)

         trace_attributes = {"db.statement": sql}
         observability_options = getattr(database, "observability_options", None)
@@ -551,7 +580,7 @@ def execute_sql(
             # lock is added to handle the inline begin for first rpc
             with self._lock:
                 return self._get_streamed_result_set(
-                    restart,
+                    wrapped_restart,
                     request,
                     trace_attributes,
                     column_info,
@@ -560,7 +589,7 @@ def execute_sql(
                 )
         else:
             return self._get_streamed_result_set(
-                restart,
+                wrapped_restart,
                 request,
                 trace_attributes,
                 column_info,
@@ -683,15 +712,27 @@ def partition_read(
             trace_attributes,
             observability_options=getattr(database, "observability_options", None),
         ):
-            method = functools.partial(
-                api.partition_read,
-                request=request,
-                metadata=metadata,
-                retry=retry,
-                timeout=timeout,
-            )
+            nth_request = getattr(database, "_next_nth_request", 0)
+            attempt = AtomicCounter(0)
+
+            def wrapped_method(*args, **kwargs):
+                attempt.increment()
+                channel_id = getattr(self._session, "_channel_id", 0)
+                client_id = getattr(database, "_nth_client_id", 0)
+                all_metadata = _metadata_with_request_id(
+                    client_id, channel_id, nth_request, attempt.value, metadata
+                )
+                method = functools.partial(
+                    api.partition_read,
+                    request=request,
+                    metadata=all_metadata,
+                    retry=retry,
+                    timeout=timeout,
+                )
+                return method(*args, **kwargs)
+
             response = _retry(
-                method,
+                wrapped_method,
                 allowed_exceptions={InternalServerError: _check_rst_stream_error},
             )

@@ -786,15 +827,28 @@ def partition_query(
             trace_attributes,
             observability_options=getattr(database, "observability_options", None),
         ):
-            method = functools.partial(
-                api.partition_query,
-                request=request,
-                metadata=metadata,
-                retry=retry,
-                timeout=timeout,
-            )
+            nth_request = getattr(database, "_next_nth_request", 0)
+            attempt = AtomicCounter(0)
+
+            def wrapped_method(*args, **kwargs):
+                attempt.increment()
+                channel_id = getattr(self._session, "_channel_id", 0)
+                client_id = getattr(database, "_nth_client_id", 0)
+                all_metadata = _metadata_with_request_id(
+                    client_id, channel_id, nth_request, attempt.value, metadata
+                )
+
+                method = functools.partial(
+                    api.partition_query,
+                    request=request,
+                    metadata=all_metadata,
+                    retry=retry,
+                    timeout=timeout,
+                )
+                return method(*args, **kwargs)
+
             response = _retry(
-                method,
+                wrapped_method,
                 allowed_exceptions={InternalServerError: _check_rst_stream_error},
             )

@@ -932,14 +986,27 @@ def begin(self):
             self._session,
             observability_options=getattr(database, "observability_options", None),
         ):
-            method = functools.partial(
-                api.begin_transaction,
-                session=self._session.name,
-                options=txn_selector.begin,
-                metadata=metadata,
-            )
+            nth_request = getattr(database, "_next_nth_request", 0)
+            attempt = AtomicCounter(0)
+
+            def wrapped_method(*args, **kwargs):
+                attempt.increment()
+                channel_id = getattr(self._session, "_channel_id", 0)
+                client_id = getattr(database, "_nth_client_id", 0)
+                all_metadata = _metadata_with_request_id(
+                    client_id, channel_id, nth_request, attempt.value, metadata
+                )
+
+                method = functools.partial(
+                    api.begin_transaction,
+                    session=self._session.name,
+                    options=txn_selector.begin,
+                    metadata=all_metadata,
+                )
+                return method(*args, **kwargs)
+
             response = _retry(
-                method,
+                wrapped_method,
                 allowed_exceptions={InternalServerError: _check_rst_stream_error},
             )
         self._transaction_id = response.id
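
Note that this commit only adds call sites for _metadata_with_request_id; the helper itself lives in google/cloud/spanner_v1/_helpers.py and is not shown in the diff. As a rough, hypothetical illustration of what such a helper could do — the real header key and value format may differ — it would append a request-id entry derived from the client, channel, request, and attempt numbers:

# Hypothetical sketch only: the real _metadata_with_request_id in
# google/cloud/spanner_v1/_helpers.py may use a different header key or
# value format than shown here.
def _metadata_with_request_id_sketch(
    nth_client_id, channel_id, nth_request, attempt, metadata=None
):
    metadata = list(metadata or [])
    metadata.append(
        (
            "x-goog-request-id",  # header name taken from the commit title
            f"{nth_client_id}.{channel_id}.{nth_request}.{attempt}",
        )
    )
    return metadata
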
4 changes: 2 additions & 2 deletions tests/unit/test_atomic_counter.py
@@ -18,6 +18,7 @@
 import unittest
 from google.cloud.spanner_v1._helpers import AtomicCounter

+
 class TestAtomicCounter(unittest.TestCase):
     def test_initialization(self):
         ac_default = AtomicCounter()
@@ -54,7 +55,6 @@ def test_plus_call(self):
         assert n == 201
         assert ac.value == 1

-
     def test_multiple_threads_incrementing(self):
         ac = AtomicCounter()
         n = 200
@@ -78,4 +78,4 @@ def do_work():
             assert th.is_alive() == False

         # Finally the result should be n*m
-        assert ac.value == n*m
+        assert ac.value == n * m
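
For reference, the multi-threaded test above only relies on AtomicCounter exposing a value property and a thread-safe increment(); test_plus_call additionally exercises addition behaviour not covered below. A minimal, lock-based sketch of the interface these tests use (the shipped class in _helpers.py may do more):

import threading


class AtomicCounterSketch:
    """Minimal, lock-based stand-in for the interface these tests exercise.

    The shipped AtomicCounter in google/cloud/spanner_v1/_helpers.py may
    support more (e.g. the addition behaviour used by test_plus_call).
    """

    def __init__(self, start_value=0):
        self._lock = threading.Lock()
        self._value = start_value

    @property
    def value(self):
        with self._lock:
            return self._value

    def increment(self, step=1):
        with self._lock:
            self._value += step
            return self._value
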
