35
35
_merge_query_options ,
36
36
_metadata_with_prefix ,
37
37
_metadata_with_leader_aware_routing ,
38
+ _metadata_with_request_id ,
38
39
_retry ,
39
40
_check_rst_stream_error ,
40
41
_SessionWrapper ,
42
+ AtomicCounter ,
41
43
)
42
44
from google .cloud .spanner_v1 ._opentelemetry_tracing import trace_call
43
45
from google .cloud .spanner_v1 .streamed import StreamedResultSet
@@ -320,13 +322,26 @@ def read(
320
322
data_boost_enabled = data_boost_enabled ,
321
323
directed_read_options = directed_read_options ,
322
324
)
323
- restart = functools .partial (
324
- api .streaming_read ,
325
- request = request ,
326
- metadata = metadata ,
327
- retry = retry ,
328
- timeout = timeout ,
329
- )
325
+
326
+ nth_request = getattr (database , "_next_nth_request" , 0 )
327
+ attempt = AtomicCounter (0 )
328
+
329
+ def wrapped_restart (* args , ** kwargs ):
330
+ attempt .increment ()
331
+ channel_id = getattr (self ._session , "_channel_id" , 0 )
332
+ client_id = getattr (database , "_nth_client_id" , 0 )
333
+ all_metadata = _metadata_with_request_id (
334
+ client_id , channel_id , nth_request , attempt .value , metadata
335
+ )
336
+
337
+ restart = functools .partial (
338
+ api .streaming_read ,
339
+ request = request ,
340
+ metadata = all_metadata ,
341
+ retry = retry ,
342
+ timeout = timeout ,
343
+ )
344
+ return restart (* args , ** kwargs )
330
345
331
346
trace_attributes = {"table_id" : table , "columns" : columns }
332
347
observability_options = getattr (database , "observability_options" , None )
@@ -335,7 +350,7 @@ def read(
335
350
# lock is added to handle the inline begin for first rpc
336
351
with self ._lock :
337
352
iterator = _restart_on_unavailable (
338
- restart ,
353
+ wrapped_restart ,
339
354
request ,
340
355
"CloudSpanner.ReadOnlyTransaction" ,
341
356
self ._session ,
@@ -357,7 +372,7 @@ def read(
357
372
)
358
373
else :
359
374
iterator = _restart_on_unavailable (
360
- restart ,
375
+ wrapped_restart ,
361
376
request ,
362
377
"CloudSpanner.ReadOnlyTransaction" ,
363
378
self ._session ,
@@ -536,13 +551,27 @@ def execute_sql(
536
551
data_boost_enabled = data_boost_enabled ,
537
552
directed_read_options = directed_read_options ,
538
553
)
539
- restart = functools .partial (
540
- api .execute_streaming_sql ,
541
- request = request ,
542
- metadata = metadata ,
543
- retry = retry ,
544
- timeout = timeout ,
545
- )
554
+
555
+ nth_request = getattr (database , "_next_nth_request" , 0 )
556
+ attempt = AtomicCounter (0 )
557
+
558
+ def wrapped_restart (* args , ** kwargs ):
559
+ attempt .increment ()
560
+ channel_id = getattr (self ._session , "_channel_id" , 0 )
561
+ client_id = getattr (database , "_nth_client_id" , 0 )
562
+ all_metadata = _metadata_with_request_id (
563
+ client_id , channel_id , nth_request , attempt .value , metadata
564
+ )
565
+
566
+ restart = functools .partial (
567
+ api .execute_streaming_sql ,
568
+ request = request ,
569
+ metadata = all_metadata ,
570
+ retry = retry ,
571
+ timeout = timeout ,
572
+ )
573
+
574
+ return restart (* args , ** kwargs )
546
575
547
576
trace_attributes = {"db.statement" : sql }
548
577
observability_options = getattr (database , "observability_options" , None )
@@ -551,7 +580,7 @@ def execute_sql(
551
580
# lock is added to handle the inline begin for first rpc
552
581
with self ._lock :
553
582
return self ._get_streamed_result_set (
554
- restart ,
583
+ wrapped_restart ,
555
584
request ,
556
585
trace_attributes ,
557
586
column_info ,
@@ -560,7 +589,7 @@ def execute_sql(
560
589
)
561
590
else :
562
591
return self ._get_streamed_result_set (
563
- restart ,
592
+ wrapped_restart ,
564
593
request ,
565
594
trace_attributes ,
566
595
column_info ,
@@ -683,15 +712,27 @@ def partition_read(
683
712
trace_attributes ,
684
713
observability_options = getattr (database , "observability_options" , None ),
685
714
):
686
- method = functools .partial (
687
- api .partition_read ,
688
- request = request ,
689
- metadata = metadata ,
690
- retry = retry ,
691
- timeout = timeout ,
692
- )
715
+ nth_request = getattr (database , "_next_nth_request" , 0 )
716
+ attempt = AtomicCounter (0 )
717
+
718
+ def wrapped_method (* args , ** kwargs ):
719
+ attempt .increment ()
720
+ channel_id = getattr (self ._session , "_channel_id" , 0 )
721
+ client_id = getattr (database , "_nth_client_id" , 0 )
722
+ all_metadata = _metadata_with_request_id (
723
+ client_id , channel_id , nth_request , attempt .value , metadata
724
+ )
725
+ method = functools .partial (
726
+ api .partition_read ,
727
+ request = request ,
728
+ metadata = all_metadata ,
729
+ retry = retry ,
730
+ timeout = timeout ,
731
+ )
732
+ return method (* args , ** kwargs )
733
+
693
734
response = _retry (
694
- method ,
735
+ wrapped_method ,
695
736
allowed_exceptions = {InternalServerError : _check_rst_stream_error },
696
737
)
697
738
@@ -786,15 +827,28 @@ def partition_query(
786
827
trace_attributes ,
787
828
observability_options = getattr (database , "observability_options" , None ),
788
829
):
789
- method = functools .partial (
790
- api .partition_query ,
791
- request = request ,
792
- metadata = metadata ,
793
- retry = retry ,
794
- timeout = timeout ,
795
- )
830
+ nth_request = getattr (database , "_next_nth_request" , 0 )
831
+ attempt = AtomicCounter (0 )
832
+
833
+ def wrapped_method (* args , ** kwargs ):
834
+ attempt .increment ()
835
+ channel_id = getattr (self ._session , "_channel_id" , 0 )
836
+ client_id = getattr (database , "_nth_client_id" , 0 )
837
+ all_metadata = _metadata_with_request_id (
838
+ client_id , channel_id , nth_request , attempt .value , metadata
839
+ )
840
+
841
+ method = functools .partial (
842
+ api .partition_query ,
843
+ request = request ,
844
+ metadata = all_metadata ,
845
+ retry = retry ,
846
+ timeout = timeout ,
847
+ )
848
+ return method (* args , ** kwargs )
849
+
796
850
response = _retry (
797
- method ,
851
+ wrapped_method ,
798
852
allowed_exceptions = {InternalServerError : _check_rst_stream_error },
799
853
)
800
854
@@ -932,14 +986,27 @@ def begin(self):
932
986
self ._session ,
933
987
observability_options = getattr (database , "observability_options" , None ),
934
988
):
935
- method = functools .partial (
936
- api .begin_transaction ,
937
- session = self ._session .name ,
938
- options = txn_selector .begin ,
939
- metadata = metadata ,
940
- )
989
+ nth_request = getattr (database , "_next_nth_request" , 0 )
990
+ attempt = AtomicCounter (0 )
991
+
992
+ def wrapped_method (* args , ** kwargs ):
993
+ attempt .increment ()
994
+ channel_id = getattr (self ._session , "_channel_id" , 0 )
995
+ client_id = getattr (database , "_nth_client_id" , 0 )
996
+ all_metadata = _metadata_with_request_id (
997
+ client_id , channel_id , nth_request , attempt .value , metadata
998
+ )
999
+
1000
+ method = functools .partial (
1001
+ api .begin_transaction ,
1002
+ session = self ._session .name ,
1003
+ options = txn_selector .begin ,
1004
+ metadata = all_metadata ,
1005
+ )
1006
+ return method (* args , ** kwargs )
1007
+
941
1008
response = _retry (
942
- method ,
1009
+ wrapped_method ,
943
1010
allowed_exceptions = {InternalServerError : _check_rst_stream_error },
944
1011
)
945
1012
self ._transaction_id = response .id
0 commit comments