Skip to content

Commit

Permalink
Wire up XGoogSpannerRequestIdInterceptor for TestDatabase checks
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Dec 16, 2024
1 parent db64d41 commit b221815
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 91 deletions.
34 changes: 30 additions & 4 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ class Database(object):

_spanner_api: SpannerClient = None

__transport_lock = threading.Lock()
__transports_to_channel_id = dict()

def __init__(
self,
database_id,
Expand Down Expand Up @@ -444,6 +447,31 @@ def spanner_api(self):
)
return self._spanner_api

@property
def _channel_id(self):
"""
Helper to retrieve the associated channelID for the spanner_api.
This property is paramount to x-goog-spanner-request-id.
"""
with self.__transport_lock:
api = self.spanner_api
channel_id = self.__transports_to_channel_id.get(api._transport, None)
if channel_id is None:
channel_id = len(self.__transports_to_channel_id) + 1
self.__transports_to_channel_id[api._transport] = channel_id

return channel_id

def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=[]):
client_id = self._nth_client_id
return _metadata_with_request_id(
self._nth_client_id,
self._channel_id,
nth_request,
nth_attempt,
prior_metadata,
)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
Expand Down Expand Up @@ -705,10 +733,8 @@ def execute_partitioned_dml(

def execute_pdml():
with SessionCheckout(self._pool) as session:
channel_id = getattr(session, "_channel_id", 0)
client_id = getattr(self, "_nth_client_id", 0)
all_metadata = _metadata_with_request_id(
client_id, channel_id, nth_request, attempt.value, metadata
all_metadata = self.metadata_with_request_id(
nth_request, attempt.value, metadata
)
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=all_metadata
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ def database(
proto_descriptors=proto_descriptors,
)
else:
print("enabled interceptors")
return TestDatabase(
database_id,
self,
Expand Down
22 changes: 8 additions & 14 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import datetime
import queue
import time
import threading

from google.cloud.exceptions import NotFound
from google.cloud.spanner_v1 import BatchCreateSessionsRequest
Expand Down Expand Up @@ -53,8 +52,6 @@ def __init__(self, labels=None, database_role=None):
labels = {}
self._labels = labels
self._database_role = database_role
self.__lock = threading.Lock()
self._session_id_to_channel_id = dict()

@property
def labels(self):
Expand Down Expand Up @@ -130,19 +127,10 @@ def _new_session(self):
:rtype: :class:`~google.cloud.spanner_v1.session.Session`
:returns: new session instance.
"""
session = self._database.session(
return self._database.session(
labels=self.labels, database_role=self.database_role
)

session_id = getattr(session, "_session_id", None)
if session_id:
with self.__lock:
channel_id = len(self._session_id_to_channel_id) + 1
self._session_id_to_channel_id[session._session_id] = channel_id
session._channel_id = channel_id

return session

def session(self, **kwargs):
"""Check out a session from the pool.
Expand Down Expand Up @@ -257,10 +245,16 @@ def bind(self, database):
f"Creating {request.session_count} sessions",
span_event_attributes,
)

attempt = 1
all_metadata = database.metadata_with_request_id(
database._next_nth_request, attempt, metadata
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
metadata=all_metadata,
)

for session_pb in resp.session:
session = self._new_session()
session._session_id = session_pb.name.split("/")[-1]
Expand Down
1 change: 0 additions & 1 deletion google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def __init__(self, database, labels=None, database_role=None):
self._labels = labels
self._database_role = database_role
self._last_use_time = datetime.utcnow()
self.__channel_id = 0

def __lt__(self, other):
return self._session_id < other._session_id
Expand Down
32 changes: 10 additions & 22 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,8 @@ def read(

def wrapped_restart(*args, **kwargs):
attempt.increment()
channel_id = getattr(self._session, "_channel_id", 0)
client_id = getattr(database, "_nth_client_id", 0)
all_metadata = _metadata_with_request_id(
client_id, channel_id, nth_request, attempt.value, metadata
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)

restart = functools.partial(
Expand Down Expand Up @@ -557,10 +555,8 @@ def execute_sql(

def wrapped_restart(*args, **kwargs):
attempt.increment()
channel_id = getattr(self._session, "_channel_id", 0)
client_id = getattr(database, "_nth_client_id", 0)
all_metadata = _metadata_with_request_id(
client_id, channel_id, nth_request, attempt.value, metadata
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)

restart = functools.partial(
Expand Down Expand Up @@ -717,10 +713,8 @@ def partition_read(

def wrapped_method(*args, **kwargs):
attempt.increment()
channel_id = getattr(self._session, "_channel_id", 0)
client_id = getattr(database, "_nth_client_id", 0)
all_metadata = _metadata_with_request_id(
client_id, channel_id, nth_request, attempt.value, metadata
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)
method = functools.partial(
api.partition_read,
Expand Down Expand Up @@ -832,12 +826,9 @@ def partition_query(

def wrapped_method(*args, **kwargs):
attempt.increment()
channel_id = getattr(self._session, "_channel_id", 0)
client_id = getattr(database, "_nth_client_id", 0)
all_metadata = _metadata_with_request_id(
client_id, channel_id, nth_request, attempt.value, metadata
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)

method = functools.partial(
api.partition_query,
request=request,
Expand Down Expand Up @@ -991,12 +982,9 @@ def begin(self):

def wrapped_method(*args, **kwargs):
attempt.increment()
channel_id = getattr(self._session, "_channel_id", 0)
client_id = getattr(database, "_nth_client_id", 0)
all_metadata = _metadata_with_request_id(
client_id, channel_id, nth_request, attempt.value, metadata
all_metadata = database.metadata_with_request_id(
nth_request, attempt.value, metadata
)

method = functools.partial(
api.begin_transaction,
session=self._session.name,
Expand Down
6 changes: 4 additions & 2 deletions google/cloud/spanner_v1/testing/database_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class TestDatabase(Database):
currently, and we don't want to make changes in the Database class for
testing purpose as this is a hack to use interceptors in tests."""

_interceptors = []

def __init__(
self,
database_id,
Expand All @@ -61,11 +63,9 @@ def __init__(

self._method_count_interceptor = MethodCountInterceptor()
self._method_abort_interceptor = MethodAbortInterceptor()
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
self._interceptors = [
self._method_count_interceptor,
self._method_abort_interceptor,
self._x_goog_request_id_interceptor,
]

@property
Expand All @@ -77,6 +77,8 @@ def spanner_api(self):
client_options = client._client_options
if self._instance.emulator_host is not None:
channel = grpc.insecure_channel(self._instance.emulator_host)
self._x_goog_request_id_interceptor = XGoogRequestIDHeaderInterceptor()
self._interceptors.append(self._x_goog_request_id_interceptor)
channel = grpc.intercept_channel(channel, *self._interceptors)
transport = SpannerGrpcTransport(channel=channel)
self._spanner_api = SpannerClient(
Expand Down
39 changes: 33 additions & 6 deletions google/cloud/spanner_v1/testing/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,28 @@ def intercept(self, method, request_or_iterator, call_details):
break

if not x_goog_request_id:
raise Exception(f"Missing {x_goog_request_id}")

streaming = hasattr(request_or_iterator, "__iter__", False)
raise Exception("Missing x_goog_request_id header")

response_or_iterator = method(request_or_iterator, call_details)
streaming = getattr(response_or_iterator, "__iter__", None) is not None
print(
"intercept got",
x_goog_request_id,
call_details.method,
"streaming",
streaming,
)
with self.__lock:
if streaming:
self._stream_req_segments.append(x_goog_request_id)
self._stream_req_segments.append(
(call_details.method, parse_request_id(x_goog_request_id))
)
else:
self._unary_req_segments.append(x_goog_request_id)
self._unary_req_segments.append(
(call_details.method, parse_request_id(x_goog_request_id))
)

return method(request_or_iterator, call_details)
return response_or_iterator

@property
def unary_request_ids(self):
Expand All @@ -105,3 +117,18 @@ def reset(self):
self._stream_req_segments.clear()
self._unary_req_segments.clear()
pass


def parse_request_id(request_id_str):
splits = request_id_str.split(".")
version, rand_process_id, client_id, channel_id, nth_request, nth_attempt = list(
map(lambda v: int(v), splits)
)
return (
version,
rand_process_id,
client_id,
channel_id,
nth_request,
nth_attempt,
)
15 changes: 4 additions & 11 deletions tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def __init__(self, *args, **kwargs):
self._client = None
self._instance = None
self._database = None
self._interceptors = None

@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -147,19 +146,11 @@ def teardown_method(self, *args, **kwargs):
@property
def client(self) -> Client:
if self._client is None:
api_endpoint = "localhost:" + str(MockServerTestBase.port)
channel = grpc.insecure_channel(api_endpoint)
transport = None
if self._interceptors and len(self._interceptors) > 0:
channel = grpc.intercept_channel(channel, *self._interceptors)
transport = SpannerGrpcTransport(channel=channel)

self._client = Client(
project="p",
credentials=AnonymousCredentials(),
client_options=ClientOptions(
transport=transport,
api_endpoint=api_endpoint if transport is None else None,
api_endpoint="localhost:" + str(MockServerTestBase.port),
),
)
return self._client
Expand All @@ -174,6 +165,8 @@ def instance(self) -> Instance:
def database(self) -> Database:
if self._database is None:
self._database = self.instance.database(
"test-database", pool=FixedSizePool(size=10)
"test-database",
pool=FixedSizePool(size=10),
enable_interceptors_in_tests=True,
)
return self._database
Loading

0 comments on commit b221815

Please sign in to comment.