From a6811afefa6739caa20203048635d94f9b85c4c8 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Fri, 6 Dec 2024 02:01:15 -0800 Subject: [PATCH] observability: annotate Session+SessionPool events (#1207) This change adds annotations for session and session pool events to aid customers in debugging latency issues caused by session pool misbehavior, and to help maintainers figure out which session pool type is the most appropriate. Updates #1170 --- google/cloud/spanner_v1/_helpers.py | 4 + .../spanner_v1/_opentelemetry_tracing.py | 19 +- google/cloud/spanner_v1/database.py | 12 + google/cloud/spanner_v1/pool.py | 173 ++++++- google/cloud/spanner_v1/session.py | 28 +- google/cloud/spanner_v1/transaction.py | 32 +- tests/_helpers.py | 39 +- tests/unit/test_batch.py | 4 + tests/unit/test_database.py | 4 + tests/unit/test_pool.py | 438 ++++++++++-------- tests/unit/test_session.py | 38 ++ tests/unit/test_snapshot.py | 4 + tests/unit/test_spanner.py | 4 + tests/unit/test_transaction.py | 4 + 14 files changed, 602 insertions(+), 201 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index a4d66fc20f..29bd604e7b 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -463,6 +463,7 @@ def _retry( retry_count=5, delay=2, allowed_exceptions=None, + beforeNextRetry=None, ): """ Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions. 
@@ -479,6 +480,9 @@ def _retry( """ retries = 0 while retries <= retry_count: + if retries > 0 and beforeNextRetry: + beforeNextRetry(retries, delay) + try: return func() except Exception as exc: diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index e5aad08c05..1caac59ecd 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -81,10 +81,11 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) tracer = get_tracer(tracer_provider) # Set base attributes that we know for every trace created + db = session._database attributes = { "db.type": "spanner", "db.url": SpannerClient.DEFAULT_ENDPOINT, - "db.instance": session._database.name, + "db.instance": "" if not db else db.name, "net.host.name": SpannerClient.DEFAULT_ENDPOINT, OTEL_SCOPE_NAME: TRACER_NAME, OTEL_SCOPE_VERSION: TRACER_VERSION, @@ -106,7 +107,10 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) yield span except Exception as error: span.set_status(Status(StatusCode.ERROR, str(error))) - span.record_exception(error) + # OpenTelemetry-Python imposes invoking span.record_exception on __exit__ + # on any exception. We should file a bug later on with them to only + # invoke .record_exception if not already invoked, hence we should not + # invoke .record_exception on our own else we shall have 2 exceptions. raise else: if (not span._status) or span._status.status_code == StatusCode.UNSET: @@ -116,3 +120,14 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) # it wasn't previously set otherwise. 
# https://github.com/googleapis/python-spanner/issues/1246 span.set_status(Status(StatusCode.OK)) + + +def get_current_span(): + if not HAS_OPENTELEMETRY_INSTALLED: + return None + return trace.get_current_span() + + +def add_span_event(span, event_name, event_attributes=None): + if span: + span.add_event(event_name, event_attributes) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1e10e1df73..c8230ab503 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -67,6 +67,10 @@ SpannerGrpcTransport, ) from google.cloud.spanner_v1.table import Table +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_span_event, + get_current_span, +) SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -1164,7 +1168,9 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" + current_span = get_current_span() session = self._session = self._database._pool.get() + add_span_event(current_span, "Using session", {"id": session.session_id}) batch = self._batch = Batch(session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag @@ -1187,6 +1193,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): extra={"commit_stats": self._batch.commit_stats}, ) self._database._pool.put(self._session) + current_span = get_current_span() + add_span_event( + current_span, + "Returned session to pool", + {"id": self._session.session_id}, + ) class MutationGroupsCheckout(object): diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index c95ef7a7b9..4f90196b4a 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -16,6 +16,7 @@ import datetime import queue +import time from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest @@ -24,6 +25,10 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) +from 
google.cloud.spanner_v1._opentelemetry_tracing import ( + add_span_event, + get_current_span, +) from warnings import warn _NOW = datetime.datetime.utcnow # unit tests may replace @@ -196,6 +201,18 @@ def bind(self, database): when needed. """ self._database = database + requested_session_count = self.size - self._sessions.qsize() + span = get_current_span() + span_event_attributes = {"kind": type(self).__name__} + + if requested_session_count <= 0: + add_span_event( + span, + f"Invalid session pool size({requested_session_count}) <= 0", + span_event_attributes, + ) + return + api = database.spanner_api metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: @@ -203,13 +220,31 @@ def bind(self, database): _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) self._database_role = self._database_role or self._database.database_role + if requested_session_count > 0: + add_span_event( + span, + f"Requesting {requested_session_count} sessions", + span_event_attributes, + ) + + if self._sessions.full(): + add_span_event(span, "Session pool is already full", span_event_attributes) + return + request = BatchCreateSessionsRequest( database=database.name, - session_count=self.size - self._sessions.qsize(), + session_count=requested_session_count, session_template=Session(creator_role=self.database_role), ) + returned_session_count = 0 while not self._sessions.full(): + request.session_count = requested_session_count - self._sessions.qsize() + add_span_event( + span, + f"Creating {request.session_count} sessions", + span_event_attributes, + ) resp = api.batch_create_sessions( request=request, metadata=metadata, @@ -218,6 +253,13 @@ def bind(self, database): session = self._new_session() session._session_id = session_pb.name.split("/")[-1] self._sessions.put(session) + returned_session_count += 1 + + add_span_event( + span, + f"Requested for {requested_session_count} sessions, returned {returned_session_count}", + 
span_event_attributes, + ) def get(self, timeout=None): """Check a session out from the pool. @@ -233,12 +275,43 @@ def get(self, timeout=None): if timeout is None: timeout = self.default_timeout - session = self._sessions.get(block=True, timeout=timeout) - age = _NOW() - session.last_use_time + start_time = time.time() + current_span = get_current_span() + span_event_attributes = {"kind": type(self).__name__} + add_span_event(current_span, "Acquiring session", span_event_attributes) - if age >= self._max_age and not session.exists(): - session = self._database.session() - session.create() + session = None + try: + add_span_event( + current_span, + "Waiting for a session to become available", + span_event_attributes, + ) + + session = self._sessions.get(block=True, timeout=timeout) + age = _NOW() - session.last_use_time + + if age >= self._max_age and not session.exists(): + if not session.exists(): + add_span_event( + current_span, + "Session is not valid, recreating it", + span_event_attributes, + ) + session = self._database.session() + session.create() + # Replacing with the updated session.id. + span_event_attributes["session.id"] = session._session_id + + span_event_attributes["session.id"] = session._session_id + span_event_attributes["time.elapsed"] = time.time() - start_time + add_span_event(current_span, "Acquired session", span_event_attributes) + + except queue.Empty as e: + add_span_event( + current_span, "No sessions available in the pool", span_event_attributes + ) + raise e return session @@ -312,13 +385,32 @@ def get(self): :returns: an existing session from the pool, or a newly-created session. 
""" + current_span = get_current_span() + span_event_attributes = {"kind": type(self).__name__} + add_span_event(current_span, "Acquiring session", span_event_attributes) + try: + add_span_event( + current_span, + "Waiting for a session to become available", + span_event_attributes, + ) session = self._sessions.get_nowait() except queue.Empty: + add_span_event( + current_span, + "No sessions available in pool. Creating session", + span_event_attributes, + ) session = self._new_session() session.create() else: if not session.exists(): + add_span_event( + current_span, + "Session is not valid, recreating it", + span_event_attributes, + ) session = self._new_session() session.create() return session @@ -427,6 +519,38 @@ def bind(self, database): session_template=Session(creator_role=self.database_role), ) + span_event_attributes = {"kind": type(self).__name__} + current_span = get_current_span() + requested_session_count = request.session_count + if requested_session_count <= 0: + add_span_event( + current_span, + f"Invalid session pool size({requested_session_count}) <= 0", + span_event_attributes, + ) + return + + add_span_event( + current_span, + f"Requesting {requested_session_count} sessions", + span_event_attributes, + ) + + if created_session_count >= self.size: + add_span_event( + current_span, + "Created no new sessions as sessionPool is full", + span_event_attributes, + ) + return + + add_span_event( + current_span, + f"Creating {request.session_count} sessions", + span_event_attributes, + ) + + returned_session_count = 0 while created_session_count < self.size: resp = api.batch_create_sessions( request=request, @@ -436,8 +560,16 @@ def bind(self, database): session = self._new_session() session._session_id = session_pb.name.split("/")[-1] self.put(session) + returned_session_count += 1 + created_session_count += len(resp.session) + add_span_event( + current_span, + f"Requested for {requested_session_count} sessions, return {returned_session_count}", + 
span_event_attributes, + ) + def get(self, timeout=None): """Check a session out from the pool. @@ -452,7 +584,26 @@ def get(self, timeout=None): if timeout is None: timeout = self.default_timeout - ping_after, session = self._sessions.get(block=True, timeout=timeout) + start_time = time.time() + span_event_attributes = {"kind": type(self).__name__} + current_span = get_current_span() + add_span_event( + current_span, + "Waiting for a session to become available", + span_event_attributes, + ) + + ping_after = None + session = None + try: + ping_after, session = self._sessions.get(block=True, timeout=timeout) + except queue.Empty as e: + add_span_event( + current_span, + "No sessions available in the pool within the specified timeout", + span_event_attributes, + ) + raise e if _NOW() > ping_after: # Using session.exists() guarantees the returned session exists. @@ -462,6 +613,14 @@ def get(self, timeout=None): session = self._new_session() session.create() + span_event_attributes.update( + { + "time.elapsed": time.time() - start_time, + "session.id": session._session_id, + "kind": "pinging_pool", + } + ) + add_span_event(current_span, "Acquired session", span_event_attributes) return session def put(self, session): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 539f36af2b..166d5488c6 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -31,7 +31,11 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_span_event, + get_current_span, + trace_call, +) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_v1.transaction import Transaction @@ -134,6 +138,9 @@ def create(self): :raises ValueError: if :attr:`session_id` is already set. 
""" + current_span = get_current_span() + add_span_event(current_span, "Creating Session") + if self._session_id is not None: raise ValueError("Session ID already set by back-end") api = self._database.spanner_api @@ -174,8 +181,18 @@ def exists(self): :rtype: bool :returns: True if the session exists on the back-end, else False. """ + current_span = get_current_span() if self._session_id is None: + add_span_event( + current_span, + "Checking session existence: Session does not exist as it has not been created yet", + ) return False + + add_span_event( + current_span, "Checking if Session exists", {"session.id": self._session_id} + ) + api = self._database.spanner_api metadata = _metadata_with_prefix(self._database.name) if self._database._route_to_leader_enabled: @@ -209,8 +226,17 @@ def delete(self): :raises ValueError: if :attr:`session_id` is not already set. :raises NotFound: if the session does not exist """ + current_span = get_current_span() if self._session_id is None: + add_span_event( + current_span, "Deleting Session failed due to unset session_id" + ) raise ValueError("Session ID not set by back-end") + + add_span_event( + current_span, "Deleting Session", {"session.id": self._session_id} + ) + api = self._database.spanner_api metadata = _metadata_with_prefix(self._database.name) observability_options = getattr(self._database, "observability_options", None) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index d99c4fde2f..fa8e5121ff 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -32,7 +32,7 @@ from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call from google.cloud.spanner_v1 import 
RequestOptions from google.api_core import gapic_v1 from google.api_core.exceptions import InternalServerError @@ -160,16 +160,25 @@ def begin(self): "CloudSpanner.BeginTransaction", self._session, observability_options=observability_options, - ): + ) as span: method = functools.partial( api.begin_transaction, session=self._session.name, options=txn_options, metadata=metadata, ) + + def beforeNextRetry(nthRetry, delayInSeconds): + add_span_event( + span, + "Transaction Begin Attempt Failed. Retrying", + {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + ) + response = _retry( method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, + beforeNextRetry=beforeNextRetry, ) self._transaction_id = response.id return self._transaction_id @@ -246,7 +255,6 @@ def commit( metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - trace_attributes = {"num_mutations": len(self._mutations)} if request_options is None: request_options = RequestOptions() @@ -266,22 +274,38 @@ def commit( max_commit_delay=max_commit_delay, request_options=request_options, ) + + trace_attributes = {"num_mutations": len(self._mutations)} observability_options = getattr(database, "observability_options", None) with trace_call( "CloudSpanner.Commit", self._session, trace_attributes, observability_options, - ): + ) as span: + add_span_event(span, "Starting Commit") + method = functools.partial( api.commit, request=request, metadata=metadata, ) + + def beforeNextRetry(nthRetry, delayInSeconds): + add_span_event( + span, + "Transaction Commit Attempt Failed. 
Retrying", + {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + ) + response = _retry( method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, + beforeNextRetry=beforeNextRetry, ) + + add_span_event(span, "Commit Done") + self.committed = response.commit_timestamp if return_commit_stats: self.commit_stats = response.commit_stats diff --git a/tests/_helpers.py b/tests/_helpers.py index 5e514f2586..81787c5a86 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -16,10 +16,11 @@ OTEL_SCOPE_NAME, OTEL_SCOPE_VERSION, ) + from opentelemetry.sdk.trace.sampling import TraceIdRatioBased from opentelemetry.trace.status import StatusCode - trace.set_tracer_provider(TracerProvider()) + trace.set_tracer_provider(TracerProvider(sampler=TraceIdRatioBased(1.0))) HAS_OPENTELEMETRY_INSTALLED = True except ImportError: @@ -86,9 +87,43 @@ def assertSpanAttributes( if HAS_OPENTELEMETRY_INSTALLED: if not span: span_list = self.ot_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) + self.assertEqual(len(span_list) > 0, True) span = span_list[0] self.assertEqual(span.name, name) self.assertEqual(span.status.status_code, status) self.assertEqual(dict(span.attributes), attributes) + + def assertSpanEvents(self, name, wantEventNames=[], span=None): + if not HAS_OPENTELEMETRY_INSTALLED: + return + + if not span: + span_list = self.ot_exporter.get_finished_spans() + self.assertEqual(len(span_list) > 0, True) + span = span_list[0] + + self.assertEqual(span.name, name) + actualEventNames = [] + for event in span.events: + actualEventNames.append(event.name) + self.assertEqual(actualEventNames, wantEventNames) + + def assertSpanNames(self, want_span_names): + if not HAS_OPENTELEMETRY_INSTALLED: + return + + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + self.assertEqual(got_span_names, want_span_names) + + def get_finished_spans(self): + if HAS_OPENTELEMETRY_INSTALLED: + return list( + filter( + lambda 
span: span and span.name, + self.ot_exporter.get_finished_spans(), + ) + ) + else: + return [] diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 2f6b5e4ae9..a7f7a6f970 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -611,6 +611,10 @@ def __init__(self, database=None, name=TestBatch.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _Database(object): name = "testing" diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 90fa0c269f..6e29255fb7 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -3188,6 +3188,10 @@ def run_in_transaction(self, func, *args, **kw): self._retried = (func, args, kw) return self._committed + @property + def session_id(self): + return self.name + class _MockIterator(object): def __init__(self, *values, **kw): diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 2e3b46fa73..fbb35201eb 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -14,10 +14,17 @@ from functools import total_ordering +import time import unittest from datetime import datetime, timedelta import mock +from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from tests._helpers import ( + OpenTelemetryBase, + StatusCode, + enrich_with_otel_scope, +) def _make_database(name="name"): @@ -133,7 +140,15 @@ def test_session_w_kwargs(self): self.assertEqual(checkout._kwargs, {"foo": "bar"}) -class TestFixedSizePool(unittest.TestCase): +class TestFixedSizePool(OpenTelemetryBase): + BASE_ATTRIBUTES = { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": "name", + "net.host.name": "spanner.googleapis.com", + } + enrich_with_otel_scope(BASE_ATTRIBUTES) + def _getTargetClass(self): from google.cloud.spanner_v1.pool import FixedSizePool @@ -216,6 +231,93 @@ def test_get_non_expired(self): self.assertTrue(session._exists_checked) 
self.assertFalse(pool._sessions.full()) + def test_spans_bind_get(self): + # This tests retrieving 1 out of 4 sessions from the session pool. + pool = self._make_one(size=4) + database = _Database("name") + SESSIONS = sorted([_Session(database) for i in range(0, 4)]) + database._sessions.extend(SESSIONS) + pool.bind(database) + + with trace_call("pool.Get", SESSIONS[0]) as span: + pool.get() + wantEventNames = [ + "Acquiring session", + "Waiting for a session to become available", + "Acquired session", + ] + self.assertSpanEvents("pool.Get", wantEventNames, span) + + # Check for the overall spans too. + self.assertSpanAttributes( + "pool.Get", + attributes=TestFixedSizePool.BASE_ATTRIBUTES, + ) + + wantEventNames = [ + "Acquiring session", + "Waiting for a session to become available", + "Acquired session", + ] + self.assertSpanEvents("pool.Get", wantEventNames) + + def test_spans_bind_get_empty_pool(self): + # Tests trying to invoke pool.get() from an empty pool. + pool = self._make_one(size=0) + database = _Database("name") + session1 = _Session(database) + with trace_call("pool.Get", session1): + try: + pool.bind(database) + database._sessions = database._sessions[:0] + pool.get() + except Exception: + pass + + wantEventNames = [ + "Invalid session pool size(0) <= 0", + "Acquiring session", + "Waiting for a session to become available", + "No sessions available in the pool", + ] + self.assertSpanEvents("pool.Get", wantEventNames) + + # Check for the overall spans too. + self.assertSpanNames(["pool.Get"]) + self.assertSpanAttributes( + "pool.Get", + attributes=TestFixedSizePool.BASE_ATTRIBUTES, + ) + + def test_spans_pool_bind(self): + # Tests the exception generated from invoking pool.bind when + # you have an empty pool. 
+ pool = self._make_one(size=1) + database = _Database("name") + SESSIONS = [] + database._sessions.extend(SESSIONS) + fauxSession = mock.Mock() + setattr(fauxSession, "_database", database) + try: + with trace_call("testBind", fauxSession): + pool.bind(database) + except Exception: + pass + + wantEventNames = [ + "Requesting 1 sessions", + "Creating 1 sessions", + "exception", + ] + self.assertSpanEvents("testBind", wantEventNames) + + # Check for the overall spans. + self.assertSpanAttributes( + "testBind", + status=StatusCode.ERROR, + attributes=TestFixedSizePool.BASE_ATTRIBUTES, + ) + def test_get_expired(self): pool = self._make_one(size=4) database = _Database("name") @@ -299,7 +401,15 @@ def test_clear(self): self.assertTrue(session._deleted) -class TestBurstyPool(unittest.TestCase): +class TestBurstyPool(OpenTelemetryBase): + BASE_ATTRIBUTES = { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": "name", + "net.host.name": "spanner.googleapis.com", + } + enrich_with_otel_scope(BASE_ATTRIBUTES) + def _getTargetClass(self): from google.cloud.spanner_v1.pool import BurstyPool @@ -347,6 +457,34 @@ def test_get_empty(self): session.create.assert_called() self.assertTrue(pool._sessions.empty()) + def test_spans_get_empty_pool(self): + # This scenario tests a pool that hasn't been filled up + # and pool.get() acquires from a pool, waiting for a session + # to become available. 
+ pool = self._make_one() + database = _Database("name") + session1 = _Session(database) + database._sessions.append(session1) + pool.bind(database) + + with trace_call("pool.Get", session1): + session = pool.get() + self.assertIsInstance(session, _Session) + self.assertIs(session._database, database) + session.create.assert_called() + self.assertTrue(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.Get", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + wantEventNames = [ + "Acquiring session", + "Waiting for a session to become available", + "No sessions available in pool. Creating session", + ] + self.assertSpanEvents("pool.Get", wantEventNames) + def test_get_non_empty_session_exists(self): pool = self._make_one() database = _Database("name") @@ -361,6 +499,30 @@ def test_get_non_empty_session_exists(self): self.assertTrue(session._exists_checked) self.assertTrue(pool._sessions.empty()) + def test_spans_get_non_empty_session_exists(self): + # Tests the spans produces when you invoke pool.bind + # and then insert a session into the pool. + pool = self._make_one() + database = _Database("name") + previous = _Session(database) + pool.bind(database) + with trace_call("pool.Get", previous): + pool.put(previous) + session = pool.get() + self.assertIs(session, previous) + session.create.assert_not_called() + self.assertTrue(session._exists_checked) + self.assertTrue(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.Get", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + self.assertSpanEvents( + "pool.Get", + ["Acquiring session", "Waiting for a session to become available"], + ) + def test_get_non_empty_session_expired(self): pool = self._make_one() database = _Database("name") @@ -388,6 +550,22 @@ def test_put_empty(self): self.assertFalse(pool._sessions.empty()) + def test_spans_put_empty(self): + # Tests the spans produced when you put sessions into an empty pool. 
+ pool = self._make_one() + database = _Database("name") + pool.bind(database) + session = _Session(database) + + with trace_call("pool.put", session): + pool.put(session) + self.assertFalse(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.put", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + def test_put_full(self): pool = self._make_one(target_size=1) database = _Database("name") @@ -402,6 +580,28 @@ def test_put_full(self): self.assertTrue(younger._deleted) self.assertIs(pool.get(), older) + def test_spans_put_full(self): + # This scenario tests the spans produced from putting an older + # session into a pool that is already full. + pool = self._make_one(target_size=1) + database = _Database("name") + pool.bind(database) + older = _Session(database) + with trace_call("pool.put", older): + pool.put(older) + self.assertFalse(pool._sessions.empty()) + + younger = _Session(database) + pool.put(younger) # discarded silently + + self.assertTrue(younger._deleted) + self.assertIs(pool.get(), older) + + self.assertSpanAttributes( + "pool.put", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + def test_put_full_expired(self): pool = self._make_one(target_size=1) database = _Database("name") @@ -426,9 +626,18 @@ def test_clear(self): pool.clear() self.assertTrue(previous._deleted) + self.assertNoSpans() + +class TestPingingPool(OpenTelemetryBase): + BASE_ATTRIBUTES = { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": "name", + "net.host.name": "spanner.googleapis.com", + } + enrich_with_otel_scope(BASE_ATTRIBUTES) -class TestPingingPool(unittest.TestCase): def _getTargetClass(self): from google.cloud.spanner_v1.pool import PingingPool @@ -505,6 +714,7 @@ def test_get_hit_no_ping(self): self.assertIs(session, SESSIONS[0]) self.assertFalse(session._exists_checked) self.assertFalse(pool._sessions.full()) + self.assertNoSpans() def test_get_hit_w_ping(self): import datetime @@ -526,6 +736,7 @@ def test_get_hit_w_ping(self): 
self.assertIs(session, SESSIONS[0]) self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) + self.assertNoSpans() def test_get_hit_w_ping_expired(self): import datetime @@ -549,6 +760,7 @@ def test_get_hit_w_ping_expired(self): session.create.assert_called() self.assertTrue(SESSIONS[0]._exists_checked) self.assertFalse(pool._sessions.full()) + self.assertNoSpans() def test_get_empty_default_timeout(self): import queue @@ -560,6 +772,7 @@ def test_get_empty_default_timeout(self): pool.get() self.assertEqual(session_queue._got, {"block": True, "timeout": 10}) + self.assertNoSpans() def test_get_empty_explicit_timeout(self): import queue @@ -571,6 +784,7 @@ def test_get_empty_explicit_timeout(self): pool.get(timeout=1) self.assertEqual(session_queue._got, {"block": True, "timeout": 1}) + self.assertNoSpans() def test_put_full(self): import queue @@ -585,6 +799,7 @@ def test_put_full(self): pool.put(_Session(database)) self.assertTrue(pool._sessions.full()) + self.assertNoSpans() def test_put_non_full(self): import datetime @@ -605,6 +820,7 @@ def test_put_non_full(self): ping_after, queued = session_queue._items[0] self.assertEqual(ping_after, now + datetime.timedelta(seconds=3000)) self.assertIs(queued, session) + self.assertNoSpans() def test_clear(self): pool = self._make_one() @@ -623,10 +839,12 @@ def test_clear(self): for session in SESSIONS: self.assertTrue(session._deleted) + self.assertNoSpans() def test_ping_empty(self): pool = self._make_one(size=1) pool.ping() # Does not raise 'Empty' + self.assertNoSpans() def test_ping_oldest_fresh(self): pool = self._make_one(size=1) @@ -638,6 +856,7 @@ def test_ping_oldest_fresh(self): pool.ping() self.assertFalse(SESSIONS[0]._pinged) + self.assertNoSpans() def test_ping_oldest_stale_but_exists(self): import datetime @@ -674,193 +893,36 @@ def test_ping_oldest_stale_and_not_exists(self): self.assertTrue(SESSIONS[0]._pinged) SESSIONS[1].create.assert_called() + self.assertNoSpans() - -class 
TestTransactionPingingPool(unittest.TestCase): - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import TransactionPingingPool - - return TransactionPingingPool - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_defaults(self): - pool = self._make_one() - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.empty()) - self.assertTrue(pool._pending_sessions.empty()) - self.assertEqual(pool.labels, {}) - self.assertIsNone(pool.database_role) - - def test_ctor_explicit(self): - labels = {"foo": "bar"} - database_role = "dummy-role" - pool = self._make_one( - size=4, - default_timeout=30, - ping_interval=1800, - labels=labels, - database_role=database_role, - ) - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 4) - self.assertEqual(pool.default_timeout, 30) - self.assertEqual(pool._delta.seconds, 1800) - self.assertTrue(pool._sessions.empty()) - self.assertTrue(pool._pending_sessions.empty()) - self.assertEqual(pool.labels, labels) - self.assertEqual(pool.database_role, database_role) - - def test_ctor_explicit_w_database_role_in_db(self): - database_role = "dummy-role" - pool = self._make_one() - database = pool._database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) - database._database_role = database_role - pool.bind(database) - self.assertEqual(pool.database_role, database_role) - - def test_bind(self): + def test_spans_get_and_leave_empty_pool(self): + # This scenario tests the spans generated from pulling a span + # out the pool and leaving it empty. 
pool = self._make_one() database = _Database("name") - SESSIONS = [_Session(database) for _ in range(10)] - database._sessions.extend(SESSIONS) - pool.bind(database) - - self.assertIs(pool._database, database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - txn = session._transaction - txn.begin.assert_not_called() - - self.assertTrue(pool._pending_sessions.empty()) - - def test_bind_w_timestamp_race(self): - import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT - - NOW = datetime.datetime.utcnow() - pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database) for _ in range(10)] - database._sessions.extend(SESSIONS) - - with _Monkey(MUT, _NOW=lambda: NOW): + session1 = _Session(database) + database._sessions.append(session1) + try: pool.bind(database) + except Exception: + pass - self.assertIs(pool._database, database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - txn = session._transaction - txn.begin.assert_not_called() - - self.assertTrue(pool._pending_sessions.empty()) - - def test_put_full(self): - import queue - - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database) for _ in range(4)] - database._sessions.extend(SESSIONS) - pool.bind(database) - - with self.assertRaises(queue.Full): - pool.put(_Session(database)) - - self.assertTrue(pool._sessions.full()) - - def 
test_put_non_full_w_active_txn(self): - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - pending = pool._pending_sessions = _Queue() - database = _Database("name") - session = _Session(database) - txn = session.transaction() - - pool.put(session) - - self.assertEqual(len(session_queue._items), 1) - _, queued = session_queue._items[0] - self.assertIs(queued, session) - - self.assertEqual(len(pending._items), 0) - txn.begin.assert_not_called() - - def test_put_non_full_w_committed_txn(self): - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - pending = pool._pending_sessions = _Queue() - database = _Database("name") - session = _Session(database) - committed = session.transaction() - committed.committed = True - - pool.put(session) - - self.assertEqual(len(session_queue._items), 0) - - self.assertEqual(len(pending._items), 1) - self.assertIs(pending._items[0], session) - self.assertIsNot(session._transaction, committed) - session._transaction.begin.assert_not_called() - - def test_put_non_full(self): - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - pending = pool._pending_sessions = _Queue() - database = _Database("name") - session = _Session(database) - - pool.put(session) - - self.assertEqual(len(session_queue._items), 0) - self.assertEqual(len(pending._items), 1) - self.assertIs(pending._items[0], session) - - self.assertFalse(pending.empty()) - - def test_begin_pending_transactions_empty(self): - pool = self._make_one(size=1) - pool.begin_pending_transactions() # no raise - - def test_begin_pending_transactions_non_empty(self): - pool = self._make_one(size=1) - pool._sessions = _Queue() - - database = _Database("name") - TRANSACTIONS = [_make_transaction(object())] - PENDING_SESSIONS = [_Session(database, transaction=txn) for txn in TRANSACTIONS] - - pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS) - self.assertFalse(pending.empty()) - - pool.begin_pending_transactions() 
# no raise - - for txn in TRANSACTIONS: - txn.begin.assert_not_called() - - self.assertTrue(pending.empty()) + with trace_call("pool.Get", session1): + session = pool.get() + self.assertIsInstance(session, _Session) + self.assertIs(session._database, database) + # session.create.assert_called() + self.assertTrue(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.Get", + attributes=TestPingingPool.BASE_ATTRIBUTES, + ) + wantEventNames = [ + "Waiting for a session to become available", + "Acquired session", + ] + self.assertSpanEvents("pool.Get", wantEventNames) class TestSessionCheckout(unittest.TestCase): @@ -945,6 +1007,8 @@ def __init__( self._deleted = False self._transaction = transaction self._last_use_time = last_use_time + # Generate a faux id. + self._session_id = f"{time.time()}" def __lt__(self, other): return id(self) < id(other) @@ -975,6 +1039,10 @@ def transaction(self): txn = self._transaction = _make_transaction(self) return txn + @property + def session_id(self): + return self._session_id + class _Database(object): def __init__(self, name): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 2ae0cb94b8..966adadcbd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -15,6 +15,7 @@ import google.api_core.gapic_v1.method from google.cloud.spanner_v1 import RequestOptions +from google.cloud.spanner_v1._opentelemetry_tracing import trace_call import mock from tests._helpers import ( OpenTelemetryBase, @@ -174,6 +175,43 @@ def test_create_w_database_role(self): "CloudSpanner.CreateSession", attributes=TestSession.BASE_ATTRIBUTES ) + def test_create_session_span_annotations(self): + from google.cloud.spanner_v1 import CreateSessionRequest + from google.cloud.spanner_v1 import Session as SessionRequestProto + + session_pb = self._make_session_pb( + self.SESSION_NAME, database_role=self.DATABASE_ROLE + ) + + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + 
database = self._make_database(database_role=self.DATABASE_ROLE) + database.spanner_api = gax_api + session = self._make_one(database, database_role=self.DATABASE_ROLE) + + with trace_call("TestSessionSpan", session) as span: + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + self.assertEqual(session.database_role, self.DATABASE_ROLE) + session_template = SessionRequestProto(creator_role=self.DATABASE_ROLE) + + request = CreateSessionRequest( + database=database.name, + session=session_template, + ) + + gax_api.create_session.assert_called_once_with( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ], + ) + + wantEventNames = ["Creating Session"] + self.assertSpanEvents("TestSessionSpan", wantEventNames, span) + def test_create_wo_database_role(self): from google.cloud.spanner_v1 import CreateSessionRequest diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index bf7363fef2..479a0d62e9 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1822,6 +1822,10 @@ def __init__(self, database=None, name=TestSnapshot.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _MockIterator(object): def __init__(self, *values, **kw): diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index ab5479eb3c..ff34a109af 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -1082,6 +1082,10 @@ def __init__(self, database=None, name=TestTransaction.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _MockIterator(object): def __init__(self, *values, **kw): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d52fb61db1..e426f912b2 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -939,6 +939,10 @@ def 
__init__(self, database=None, name=TestTransaction.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _FauxSpannerAPI(object): _committed = None