Skip to content

Commit bcf868b

Browse files
committed
Wrap Snapshot methods with x-goog-request-id metadata injector
1 parent 47da5e4 commit bcf868b

File tree

3 files changed

+112
-44
lines changed

3 files changed

+112
-44
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,8 @@ def execute_partitioned_dml(
700700
)
701701

702702
nth_request = getattr(self, "_next_nth_request", 0)
703-
attempt = AtomicCounter(1) # It'll be incremented inside _restart_on_unavailable
703+
# Attempt will be incremented inside _restart_on_unavailable.
704+
attempt = AtomicCounter(1)
704705

705706
def execute_pdml():
706707
with SessionCheckout(self._pool) as session:

google/cloud/spanner_v1/snapshot.py

Lines changed: 108 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@
3535
_merge_query_options,
3636
_metadata_with_prefix,
3737
_metadata_with_leader_aware_routing,
38+
_metadata_with_request_id,
3839
_retry,
3940
_check_rst_stream_error,
4041
_SessionWrapper,
42+
AtomicCounter,
4143
)
4244
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
4345
from google.cloud.spanner_v1.streamed import StreamedResultSet
@@ -320,13 +322,26 @@ def read(
320322
data_boost_enabled=data_boost_enabled,
321323
directed_read_options=directed_read_options,
322324
)
323-
restart = functools.partial(
324-
api.streaming_read,
325-
request=request,
326-
metadata=metadata,
327-
retry=retry,
328-
timeout=timeout,
329-
)
325+
326+
nth_request = getattr(database, "_next_nth_request", 0)
327+
attempt = AtomicCounter(0)
328+
329+
def wrapped_restart(*args, **kwargs):
330+
attempt.increment()
331+
channel_id = getattr(self._session, "_channel_id", 0)
332+
client_id = getattr(database, "_nth_client_id", 0)
333+
all_metadata = _metadata_with_request_id(
334+
client_id, channel_id, nth_request, attempt.value, metadata
335+
)
336+
337+
restart = functools.partial(
338+
api.streaming_read,
339+
request=request,
340+
metadata=all_metadata,
341+
retry=retry,
342+
timeout=timeout,
343+
)
344+
return restart(*args, **kwargs)
330345

331346
trace_attributes = {"table_id": table, "columns": columns}
332347
observability_options = getattr(database, "observability_options", None)
@@ -335,7 +350,7 @@ def read(
335350
# lock is added to handle the inline begin for first rpc
336351
with self._lock:
337352
iterator = _restart_on_unavailable(
338-
restart,
353+
wrapped_restart,
339354
request,
340355
"CloudSpanner.ReadOnlyTransaction",
341356
self._session,
@@ -357,7 +372,7 @@ def read(
357372
)
358373
else:
359374
iterator = _restart_on_unavailable(
360-
restart,
375+
wrapped_restart,
361376
request,
362377
"CloudSpanner.ReadOnlyTransaction",
363378
self._session,
@@ -536,13 +551,27 @@ def execute_sql(
536551
data_boost_enabled=data_boost_enabled,
537552
directed_read_options=directed_read_options,
538553
)
539-
restart = functools.partial(
540-
api.execute_streaming_sql,
541-
request=request,
542-
metadata=metadata,
543-
retry=retry,
544-
timeout=timeout,
545-
)
554+
555+
nth_request = getattr(database, "_next_nth_request", 0)
556+
attempt = AtomicCounter(0)
557+
558+
def wrapped_restart(*args, **kwargs):
559+
attempt.increment()
560+
channel_id = getattr(self._session, "_channel_id", 0)
561+
client_id = getattr(database, "_nth_client_id", 0)
562+
all_metadata = _metadata_with_request_id(
563+
client_id, channel_id, nth_request, attempt.value, metadata
564+
)
565+
566+
restart = functools.partial(
567+
api.execute_streaming_sql,
568+
request=request,
569+
metadata=all_metadata,
570+
retry=retry,
571+
timeout=timeout,
572+
)
573+
574+
return restart(*args, **kwargs)
546575

547576
trace_attributes = {"db.statement": sql}
548577
observability_options = getattr(database, "observability_options", None)
@@ -551,7 +580,7 @@ def execute_sql(
551580
# lock is added to handle the inline begin for first rpc
552581
with self._lock:
553582
return self._get_streamed_result_set(
554-
restart,
583+
wrapped_restart,
555584
request,
556585
trace_attributes,
557586
column_info,
@@ -560,7 +589,7 @@ def execute_sql(
560589
)
561590
else:
562591
return self._get_streamed_result_set(
563-
restart,
592+
wrapped_restart,
564593
request,
565594
trace_attributes,
566595
column_info,
@@ -683,15 +712,27 @@ def partition_read(
683712
trace_attributes,
684713
observability_options=getattr(database, "observability_options", None),
685714
):
686-
method = functools.partial(
687-
api.partition_read,
688-
request=request,
689-
metadata=metadata,
690-
retry=retry,
691-
timeout=timeout,
692-
)
715+
nth_request = getattr(database, "_next_nth_request", 0)
716+
attempt = AtomicCounter(0)
717+
718+
def wrapped_method(*args, **kwargs):
719+
attempt.increment()
720+
channel_id = getattr(self._session, "_channel_id", 0)
721+
client_id = getattr(database, "_nth_client_id", 0)
722+
all_metadata = _metadata_with_request_id(
723+
client_id, channel_id, nth_request, attempt.value, metadata
724+
)
725+
method = functools.partial(
726+
api.partition_read,
727+
request=request,
728+
metadata=all_metadata,
729+
retry=retry,
730+
timeout=timeout,
731+
)
732+
return method(*args, **kwargs)
733+
693734
response = _retry(
694-
method,
735+
wrapped_method,
695736
allowed_exceptions={InternalServerError: _check_rst_stream_error},
696737
)
697738

@@ -786,15 +827,28 @@ def partition_query(
786827
trace_attributes,
787828
observability_options=getattr(database, "observability_options", None),
788829
):
789-
method = functools.partial(
790-
api.partition_query,
791-
request=request,
792-
metadata=metadata,
793-
retry=retry,
794-
timeout=timeout,
795-
)
830+
nth_request = getattr(database, "_next_nth_request", 0)
831+
attempt = AtomicCounter(0)
832+
833+
def wrapped_method(*args, **kwargs):
834+
attempt.increment()
835+
channel_id = getattr(self._session, "_channel_id", 0)
836+
client_id = getattr(database, "_nth_client_id", 0)
837+
all_metadata = _metadata_with_request_id(
838+
client_id, channel_id, nth_request, attempt.value, metadata
839+
)
840+
841+
method = functools.partial(
842+
api.partition_query,
843+
request=request,
844+
metadata=all_metadata,
845+
retry=retry,
846+
timeout=timeout,
847+
)
848+
return method(*args, **kwargs)
849+
796850
response = _retry(
797-
method,
851+
wrapped_method,
798852
allowed_exceptions={InternalServerError: _check_rst_stream_error},
799853
)
800854

@@ -932,14 +986,27 @@ def begin(self):
932986
self._session,
933987
observability_options=getattr(database, "observability_options", None),
934988
):
935-
method = functools.partial(
936-
api.begin_transaction,
937-
session=self._session.name,
938-
options=txn_selector.begin,
939-
metadata=metadata,
940-
)
989+
nth_request = getattr(database, "_next_nth_request", 0)
990+
attempt = AtomicCounter(0)
991+
992+
def wrapped_method(*args, **kwargs):
993+
attempt.increment()
994+
channel_id = getattr(self._session, "_channel_id", 0)
995+
client_id = getattr(database, "_nth_client_id", 0)
996+
all_metadata = _metadata_with_request_id(
997+
client_id, channel_id, nth_request, attempt.value, metadata
998+
)
999+
1000+
method = functools.partial(
1001+
api.begin_transaction,
1002+
session=self._session.name,
1003+
options=txn_selector.begin,
1004+
metadata=all_metadata,
1005+
)
1006+
return method(*args, **kwargs)
1007+
9411008
response = _retry(
942-
method,
1009+
wrapped_method,
9431010
allowed_exceptions={InternalServerError: _check_rst_stream_error},
9441011
)
9451012
self._transaction_id = response.id

tests/unit/test_atomic_counter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import unittest
1919
from google.cloud.spanner_v1._helpers import AtomicCounter
2020

21+
2122
class TestAtomicCounter(unittest.TestCase):
2223
def test_initialization(self):
2324
ac_default = AtomicCounter()
@@ -54,7 +55,6 @@ def test_plus_call(self):
5455
assert n == 201
5556
assert ac.value == 1
5657

57-
5858
def test_multiple_threads_incrementing(self):
5959
ac = AtomicCounter()
6060
n = 200
@@ -78,4 +78,4 @@ def do_work():
7878
assert th.is_alive() == False
7979

8080
# Finally the result should be n*m
81-
assert ac.value == n*m
81+
assert ac.value == n * m

0 commit comments

Comments
 (0)