Skip to content

Commit b221815

Browse files
committed
Wire up XGoogSpannerRequestIdInterceptor for TestDatabase checks
1 parent db64d41 commit b221815

File tree

9 files changed

+117
-91
lines changed

9 files changed

+117
-91
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ class Database(object):
150150

151151
_spanner_api: SpannerClient = None
152152

153+
__transport_lock = threading.Lock()
154+
__transports_to_channel_id = dict()
155+
153156
def __init__(
154157
self,
155158
database_id,
@@ -444,6 +447,31 @@ def spanner_api(self):
444447
)
445448
return self._spanner_api
446449

450+
@property
451+
def _channel_id(self):
452+
"""
453+
Helper to retrieve the associated channelID for the spanner_api.
454+
This property is paramount to x-goog-spanner-request-id.
455+
"""
456+
with self.__transport_lock:
457+
api = self.spanner_api
458+
channel_id = self.__transports_to_channel_id.get(api._transport, None)
459+
if channel_id is None:
460+
channel_id = len(self.__transports_to_channel_id) + 1
461+
self.__transports_to_channel_id[api._transport] = channel_id
462+
463+
return channel_id
464+
465+
def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
466+
client_id = self._nth_client_id
467+
return _metadata_with_request_id(
468+
self._nth_client_id,
469+
self._channel_id,
470+
nth_request,
471+
nth_attempt,
472+
prior_metadata,
473+
)
474+
447475
def __eq__(self, other):
448476
if not isinstance(other, self.__class__):
449477
return NotImplemented
@@ -705,10 +733,8 @@ def execute_partitioned_dml(
705733

706734
def execute_pdml():
707735
with SessionCheckout(self._pool) as session:
708-
channel_id = getattr(session, "_channel_id", 0)
709-
client_id = getattr(self, "_nth_client_id", 0)
710-
all_metadata = _metadata_with_request_id(
711-
client_id, channel_id, nth_request, attempt.value, metadata
736+
all_metadata = self.metadata_with_request_id(
737+
nth_request, attempt.value, metadata
712738
)
713739
txn = api.begin_transaction(
714740
session=session.name, options=txn_options, metadata=all_metadata

google/cloud/spanner_v1/instance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ def database(
501501
proto_descriptors=proto_descriptors,
502502
)
503503
else:
504+
print("enabled interceptors")
504505
return TestDatabase(
505506
database_id,
506507
self,

google/cloud/spanner_v1/pool.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import datetime
1818
import queue
1919
import time
20-
import threading
2120

2221
from google.cloud.exceptions import NotFound
2322
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
@@ -53,8 +52,6 @@ def __init__(self, labels=None, database_role=None):
5352
labels = {}
5453
self._labels = labels
5554
self._database_role = database_role
56-
self.__lock = threading.Lock()
57-
self._session_id_to_channel_id = dict()
5855

5956
@property
6057
def labels(self):
@@ -130,19 +127,10 @@ def _new_session(self):
130127
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
131128
:returns: new session instance.
132129
"""
133-
session = self._database.session(
130+
return self._database.session(
134131
labels=self.labels, database_role=self.database_role
135132
)
136133

137-
session_id = getattr(session, "_session_id", None)
138-
if session_id:
139-
with self.__lock:
140-
channel_id = len(self._session_id_to_channel_id) + 1
141-
self._session_id_to_channel_id[session._session_id] = channel_id
142-
session._channel_id = channel_id
143-
144-
return session
145-
146134
def session(self, **kwargs):
147135
"""Check out a session from the pool.
148136
@@ -257,10 +245,16 @@ def bind(self, database):
257245
f"Creating {request.session_count} sessions",
258246
span_event_attributes,
259247
)
248+
249+
attempt = 1
250+
all_metadata = database.metadata_with_request_id(
251+
database._next_nth_request, attempt, metadata
252+
)
260253
resp = api.batch_create_sessions(
261254
request=request,
262-
metadata=metadata,
255+
metadata=all_metadata,
263256
)
257+
264258
for session_pb in resp.session:
265259
session = self._new_session()
266260
session._session_id = session_pb.name.split("/")[-1]

google/cloud/spanner_v1/session.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def __init__(self, database, labels=None, database_role=None):
7575
self._labels = labels
7676
self._database_role = database_role
7777
self._last_use_time = datetime.utcnow()
78-
self.__channel_id = 0
7978

8079
def __lt__(self, other):
8180
return self._session_id < other._session_id

google/cloud/spanner_v1/snapshot.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -328,10 +328,8 @@ def read(
328328

329329
def wrapped_restart(*args, **kwargs):
330330
attempt.increment()
331-
channel_id = getattr(self._session, "_channel_id", 0)
332-
client_id = getattr(database, "_nth_client_id", 0)
333-
all_metadata = _metadata_with_request_id(
334-
client_id, channel_id, nth_request, attempt.value, metadata
331+
all_metadata = database.metadata_with_request_id(
332+
nth_request, attempt.value, metadata
335333
)
336334

337335
restart = functools.partial(
@@ -557,10 +555,8 @@ def execute_sql(
557555

558556
def wrapped_restart(*args, **kwargs):
559557
attempt.increment()
560-
channel_id = getattr(self._session, "_channel_id", 0)
561-
client_id = getattr(database, "_nth_client_id", 0)
562-
all_metadata = _metadata_with_request_id(
563-
client_id, channel_id, nth_request, attempt.value, metadata
558+
all_metadata = database.metadata_with_request_id(
559+
nth_request, attempt.value, metadata
564560
)
565561

566562
restart = functools.partial(
@@ -717,10 +713,8 @@ def partition_read(
717713

718714
def wrapped_method(*args, **kwargs):
719715
attempt.increment()
720-
channel_id = getattr(self._session, "_channel_id", 0)
721-
client_id = getattr(database, "_nth_client_id", 0)
722-
all_metadata = _metadata_with_request_id(
723-
client_id, channel_id, nth_request, attempt.value, metadata
716+
all_metadata = database.metadata_with_request_id(
717+
nth_request, attempt.value, metadata
724718
)
725719
method = functools.partial(
726720
api.partition_read,
@@ -832,12 +826,9 @@ def partition_query(
832826

833827
def wrapped_method(*args, **kwargs):
834828
attempt.increment()
835-
channel_id = getattr(self._session, "_channel_id", 0)
836-
client_id = getattr(database, "_nth_client_id", 0)
837-
all_metadata = _metadata_with_request_id(
838-
client_id, channel_id, nth_request, attempt.value, metadata
829+
all_metadata = database.metadata_with_request_id(
830+
nth_request, attempt.value, metadata
839831
)
840-
841832
method = functools.partial(
842833
api.partition_query,
843834
request=request,
@@ -991,12 +982,9 @@ def begin(self):
991982

992983
def wrapped_method(*args, **kwargs):
993984
attempt.increment()
994-
channel_id = getattr(self._session, "_channel_id", 0)
995-
client_id = getattr(database, "_nth_client_id", 0)
996-
all_metadata = _metadata_with_request_id(
997-
client_id, channel_id, nth_request, attempt.value, metadata
985+
all_metadata = database.metadata_with_request_id(
986+
nth_request, attempt.value, metadata
998987
)
999-
1000988
method = functools.partial(
1001989
api.begin_transaction,
1002990
session=self._session.name,

google/cloud/spanner_v1/testing/database_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ class TestDatabase(Database):
3535
currently, and we don't want to make changes in the Database class for
3636
testing purpose as this is a hack to use interceptors in tests."""
3737

38+
_interceptors = []
39+
3840
def __init__(
3941
self,
4042
database_id,
@@ -61,11 +63,9 @@ def __init__(
6163

6264
self._method_count_interceptor = MethodCountInterceptor()
6365
self._method_abort_interceptor = MethodAbortInterceptor()
64-
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
6566
self._interceptors = [
6667
self._method_count_interceptor,
6768
self._method_abort_interceptor,
68-
self._x_goog_request_id_interceptor,
6969
]
7070

7171
@property
@@ -77,6 +77,8 @@ def spanner_api(self):
7777
client_options = client._client_options
7878
if self._instance.emulator_host is not None:
7979
channel = grpc.insecure_channel(self._instance.emulator_host)
80+
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
81+
self._interceptors.append(self._x_goog_request_id_interceptor)
8082
channel = grpc.intercept_channel(channel, *self._interceptors)
8183
transport = SpannerGrpcTransport(channel=channel)
8284
self._spanner_api = SpannerClient(

google/cloud/spanner_v1/testing/interceptors.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,28 @@ def intercept(self, method, request_or_iterator, call_details):
8282
break
8383

8484
if not x_goog_request_id:
85-
raise Exception(f"Missing {x_goog_request_id}")
86-
87-
streaming = hasattr(request_or_iterator, "__iter__", False)
85+
raise Exception("Missing x_goog_request_id header")
86+
87+
response_or_iterator = method(request_or_iterator, call_details)
88+
streaming = getattr(response_or_iterator, "__iter__", None) is not None
89+
print(
90+
"intercept got",
91+
x_goog_request_id,
92+
call_details.method,
93+
"streaming",
94+
streaming,
95+
)
8896
with self.__lock:
8997
if streaming:
90-
self._stream_req_segments.append(x_goog_request_id)
98+
self._stream_req_segments.append(
99+
(call_details.method, parse_request_id(x_goog_request_id))
100+
)
91101
else:
92-
self._unary_req_segments.append(x_goog_request_id)
102+
self._unary_req_segments.append(
103+
(call_details.method, parse_request_id(x_goog_request_id))
104+
)
93105

94-
return method(request_or_iterator, call_details)
106+
return response_or_iterator
95107

96108
@property
97109
def unary_request_ids(self):
@@ -105,3 +117,18 @@ def reset(self):
105117
self._stream_req_segments.clear()
106118
self._unary_req_segments.clear()
107119
pass
120+
121+
122+
def parse_request_id(request_id_str):
123+
splits = request_id_str.split(".")
124+
version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list(
125+
map(lambda v: int(v), splits)
126+
)
127+
return (
128+
version,
129+
rand_process_id,
130+
client_id,
131+
channel_id,
132+
nth_request,
133+
nth_attempt,
134+
)

tests/mockserver_tests/mock_server_test_base.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,6 @@ def __init__(self, *args, **kwargs):
118118
self._client = None
119119
self._instance = None
120120
self._database = None
121-
self._interceptors = None
122121

123122
@classmethod
124123
def setup_class(cls):
@@ -147,19 +146,11 @@ def teardown_method(self, *args, **kwargs):
147146
@property
148147
def client(self) -> Client:
149148
if self._client is None:
150-
api_endpoint = "localhost:" + str(MockServerTestBase.port)
151-
channel = grpc.insecure_channel(api_endpoint)
152-
transport = None
153-
if self._interceptors and len(self._interceptors) > 0:
154-
channel = grpc.intercept_channel(channel, *self._interceptors)
155-
transport = SpannerGrpcTransport(channel=channel)
156-
157149
self._client = Client(
158150
project="p",
159151
credentials=AnonymousCredentials(),
160152
client_options=ClientOptions(
161-
transport=transport,
162-
api_endpoint=api_endpoint if transport is None else None,
153+
api_endpoint="localhost:" + str(MockServerTestBase.port),
163154
),
164155
)
165156
return self._client
@@ -174,6 +165,8 @@ def instance(self) -> Instance:
174165
def database(self) -> Database:
175166
if self._database is None:
176167
self._database = self.instance.database(
177-
"test-database", pool=FixedSizePool(size=10)
168+
"test-database",
169+
pool=FixedSizePool(size=10),
170+
enable_interceptors_in_tests=True,
178171
)
179172
return self._database

0 commit comments

Comments
 (0)