Skip to content

Commit

Permalink
Inject header in more Session using spots plus more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Dec 16, 2024
1 parent b221815 commit caa60c2
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 7 deletions.
21 changes: 16 additions & 5 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def exists(self):
current_span, "Checking if Session exists", {"session.id": self._session_id}
)

api = self._database.spanner_api
database = self._database
api = database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
if self._database._route_to_leader_enabled:
metadata.append(
Expand All @@ -202,12 +203,16 @@ def exists(self):
)
)

all_metadata = database.metadata_with_request_id(
database._next_nth_request, 1, metadata
)

observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.GetSession", self, observability_options=observability_options
) as span:
try:
api.get_session(name=self.name, metadata=metadata)
api.get_session(name=self.name, metadata=all_metadata)
if span:
span.set_attribute("session_found", True)
except NotFound:
Expand Down Expand Up @@ -237,8 +242,11 @@ def delete(self):
current_span, "Deleting Session", {"session.id": self._session_id}
)

api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
database = self._database
api = database.spanner_api
metadata = database.metadata_with_request_id(
database._next_nth_request, 1, _metadata_with_prefix(database.name)
)
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.DeleteSession",
Expand All @@ -255,7 +263,10 @@ def ping(self):
if self._session_id is None:
raise ValueError("Session ID not set by back-end")
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
database = self._database
metadata = database.metadata_with_request_id(
database._next_nth_request, 1, _metadata_with_prefix(database.name)
)
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
api.execute_sql(request=request, metadata=metadata)
self._last_use_time = datetime.now()
Expand Down
9 changes: 7 additions & 2 deletions google/cloud/spanner_v1/testing/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def reset(self):
self._connection = None


X_GOOG_REQUEST_ID = "x-goog-spanner-request-id"


class XGoogRequestIDHeaderInterceptor(ClientInterceptor):
def __init__(self):
self._unary_req_segments = []
Expand All @@ -77,12 +80,14 @@ def intercept(self, method, request_or_iterator, call_details):
metadata = call_details.metadata
x_goog_request_id = None
for key, value in metadata:
if key == "x-goog-spanner-request-id":
if key == X_GOOG_REQUEST_ID:
x_goog_request_id = value
break

if not x_goog_request_id:
raise Exception("Missing x_goog_request_id header")
raise Exception(
f"Missing {X_GOOG_REQUEST_ID} header in {call_details.method}"
)

response_or_iterator = method(request_or_iterator, call_details)
streaming = getattr(response_or_iterator, "__iter__", None) is not None
Expand Down
130 changes: 130 additions & 0 deletions tests/unit/test_request_id_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import threading

from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_select1_result,
Expand Down Expand Up @@ -63,6 +66,133 @@ def test_snapshot_read(self):
assert got_unary_segments == want_unary_segments
assert got_stream_segments == want_stream_segments

def test_snapshot_read_concurrent(self):
    """Verify x-goog-spanner-request-id sequencing across many back-to-back
    snapshot reads on the same client/channel.

    Each read performs a GetSession (unary) followed by an
    ExecuteStreamingSql (stream), so the monotonically increasing request
    sequence numbers alternate between the two request kinds.
    """

    def select1():
        # One snapshot read: expects exactly one row whose value is 1.
        with self.database.snapshot() as snapshot:
            rows = snapshot.execute_sql("select 1")
            res_list = []
            for row in rows:
                self.assertEqual(1, row[0])
                res_list.append(row)
            self.assertEqual(1, len(res_list))

    n = 10
    for i in range(n):
        # NOTE(review): Thread.run() executes select1 inline on the current
        # thread — nothing runs concurrently, which is what makes the
        # hard-coded sequence numbers below deterministic. Using .start()
        # would exercise real concurrency but randomize request ordering;
        # confirm which behavior this test is meant to pin down.
        th = threading.Thread(target=select1, name=f"snapshot-select1-{i}")
        th.run()
    # The threads were never start()ed, so there is nothing to join or wait
    # for here; all n reads have already completed synchronously. (The
    # previous busy-wait loop was dead code and referenced an unimported
    # `time` module on a path that could never execute.)

    requests = self.spanner_service.requests
    self.assertEqual(n * 2, len(requests), msg=requests)

    client_id = self.database._nth_client_id
    channel_id = self.database._channel_id
    got_stream_segments, got_unary_segments = self.canonicalize_request_id_headers()

    # Unary calls: one BatchCreateSessions takes sequence number 1, then
    # the subsequent reads issue GetSession on the odd numbers 3..(2n-1).
    want_unary_segments = [
        (
            "/google.spanner.v1.Spanner/BatchCreateSessions",
            (1, REQ_RAND_PROCESS_ID, client_id, channel_id, 1, 1),
        )
    ] + [
        (
            "/google.spanner.v1.Spanner/GetSession",
            (1, REQ_RAND_PROCESS_ID, client_id, channel_id, seq, 1),
        )
        for seq in range(3, 2 * n, 2)
    ]
    assert got_unary_segments == want_unary_segments

    # Streaming calls: each read's ExecuteStreamingSql takes the even
    # sequence numbers 2..2n.
    want_stream_segments = [
        (
            "/google.spanner.v1.Spanner/ExecuteStreamingSql",
            (1, REQ_RAND_PROCESS_ID, client_id, channel_id, seq, 1),
        )
        for seq in range(2, 2 * n + 1, 2)
    ]
    assert got_stream_segments == want_stream_segments

def canonicalize_request_id_headers(self):
    """Return the (streaming, unary) request-id segments captured by the
    database's x-goog-spanner-request-id interceptor."""
    interceptor = self.database._x_goog_request_id_interceptor
    return (
        interceptor._stream_req_segments,
        interceptor._unary_req_segments,
    )

0 comments on commit caa60c2

Please sign in to comment.