diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index a4d66fc20f..29bd604e7b 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -463,6 +463,7 @@ def _retry( retry_count=5, delay=2, allowed_exceptions=None, + beforeNextRetry=None, ): """ Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions. @@ -479,6 +480,9 @@ def _retry( """ retries = 0 while retries <= retry_count: + if retries > 0 and beforeNextRetry: + beforeNextRetry(retries, delay) + try: return func() except Exception as exc: diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index e5aad08c05..1caac59ecd 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -81,10 +81,11 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) tracer = get_tracer(tracer_provider) # Set base attributes that we know for every trace created + db = session._database attributes = { "db.type": "spanner", "db.url": SpannerClient.DEFAULT_ENDPOINT, - "db.instance": session._database.name, + "db.instance": "" if not db else db.name, "net.host.name": SpannerClient.DEFAULT_ENDPOINT, OTEL_SCOPE_NAME: TRACER_NAME, OTEL_SCOPE_VERSION: TRACER_VERSION, @@ -106,7 +107,10 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) yield span except Exception as error: span.set_status(Status(StatusCode.ERROR, str(error))) - span.record_exception(error) + # OpenTelemetry-Python imposes invoking span.record_exception on __exit__ + # on any exception. We should file a bug later on with them to only + # invoke .record_exception if not already invoked, hence we should not + # invoke .record_exception on our own else we shall have 2 exceptions. raise else: if (not span._status) or span._status.status_code == StatusCode.UNSET: @@ -116,3 +120,14 @@ def trace_call(name, session, extra_attributes=None, observability_options=None) # it wasn't previously set otherwise. # https://github.com/googleapis/python-spanner/issues/1246 span.set_status(Status(StatusCode.OK)) + + +def get_current_span(): + if not HAS_OPENTELEMETRY_INSTALLED: + return None + return trace.get_current_span() + + +def add_span_event(span, event_name, event_attributes=None): + if span: + span.add_event(event_name, event_attributes) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1e10e1df73..c8230ab503 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -67,6 +67,10 @@ SpannerGrpcTransport, ) from google.cloud.spanner_v1.table import Table +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_span_event, + get_current_span, +) SPANNER_DATA_SCOPE = "https://www.googleapis.com/auth/spanner.data" @@ -1164,7 +1168,9 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" + current_span = get_current_span() session = self._session = self._database._pool.get() + add_span_event(current_span, "Using session", {"id": session.session_id}) batch = self._batch = Batch(session) if self._request_options.transaction_tag: batch.transaction_tag = self._request_options.transaction_tag @@ -1187,6 +1193,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): extra={"commit_stats": self._batch.commit_stats}, ) self._database._pool.put(self._session) + current_span = get_current_span() + add_span_event( + current_span, + "Returned session to pool", + {"id": self._session.session_id}, + ) class MutationGroupsCheckout(object): diff --git a/google/cloud/spanner_v1/pool.py b/google/cloud/spanner_v1/pool.py index c95ef7a7b9..4f90196b4a 100644 --- a/google/cloud/spanner_v1/pool.py +++ b/google/cloud/spanner_v1/pool.py @@ -16,6 +16,7 @@ import datetime import queue +import time from google.cloud.exceptions import NotFound from google.cloud.spanner_v1 import BatchCreateSessionsRequest @@ -24,6 +25,10 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_span_event, + get_current_span, +) from warnings import warn _NOW = datetime.datetime.utcnow # unit tests may replace @@ -196,6 +201,18 @@ def bind(self, database): when needed. """ self._database = database + requested_session_count = self.size - self._sessions.qsize() + span = get_current_span() + span_event_attributes = {"kind": type(self).__name__} + + if requested_session_count <= 0: + add_span_event( + span, + f"Invalid session pool size({requested_session_count}) <= 0", + span_event_attributes, + ) + return + api = database.spanner_api metadata = _metadata_with_prefix(database.name) if database._route_to_leader_enabled: @@ -203,13 +220,31 @@ def bind(self, database): _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) self._database_role = self._database_role or self._database.database_role + if requested_session_count > 0: + add_span_event( + span, + f"Requesting {requested_session_count} sessions", + span_event_attributes, + ) + + if self._sessions.full(): + add_span_event(span, "Session pool is already full", span_event_attributes) + return + request = BatchCreateSessionsRequest( database=database.name, - session_count=self.size - self._sessions.qsize(), + session_count=requested_session_count, session_template=Session(creator_role=self.database_role), ) + returned_session_count = 0 while not self._sessions.full(): + request.session_count = requested_session_count - self._sessions.qsize() + add_span_event( + span, + f"Creating {request.session_count} sessions", + span_event_attributes, + ) resp = api.batch_create_sessions( request=request, metadata=metadata, @@ -218,6 +253,13 @@ def bind(self, database): session = self._new_session() session._session_id = session_pb.name.split("/")[-1] self._sessions.put(session) + returned_session_count += 1 + + add_span_event( + span, + f"Requested for {requested_session_count} sessions, returned {returned_session_count}", + span_event_attributes, + ) def get(self, timeout=None): """Check a session out from the pool. @@ -233,12 +275,43 @@ def get(self, timeout=None): if timeout is None: timeout = self.default_timeout - session = self._sessions.get(block=True, timeout=timeout) - age = _NOW() - session.last_use_time + start_time = time.time() + current_span = get_current_span() + span_event_attributes = {"kind": type(self).__name__} + add_span_event(current_span, "Acquiring session", span_event_attributes) - if age >= self._max_age and not session.exists(): - session = self._database.session() - session.create() + session = None + try: + add_span_event( + current_span, + "Waiting for a session to become available", + span_event_attributes, + ) + + session = self._sessions.get(block=True, timeout=timeout) + age = _NOW() - session.last_use_time + + if age >= self._max_age and not session.exists(): + if not session.exists(): + add_span_event( + current_span, + "Session is not valid, recreating it", + span_event_attributes, + ) + session = self._database.session() + session.create() + # Replacing with the updated session.id. + span_event_attributes["session.id"] = session._session_id + + span_event_attributes["session.id"] = session._session_id + span_event_attributes["time.elapsed"] = time.time() - start_time + add_span_event(current_span, "Acquired session", span_event_attributes) + + except queue.Empty as e: + add_span_event( + current_span, "No sessions available in the pool", span_event_attributes + ) + raise e return session @@ -312,13 +385,32 @@ def get(self): :returns: an existing session from the pool, or a newly-created session. """ + current_span = get_current_span() + span_event_attributes = {"kind": type(self).__name__} + add_span_event(current_span, "Acquiring session", span_event_attributes) + try: + add_span_event( + current_span, + "Waiting for a session to become available", + span_event_attributes, + ) session = self._sessions.get_nowait() except queue.Empty: + add_span_event( + current_span, + "No sessions available in pool. Creating session", + span_event_attributes, + ) session = self._new_session() session.create() else: if not session.exists(): + add_span_event( + current_span, + "Session is not valid, recreating it", + span_event_attributes, + ) session = self._new_session() session.create() return session @@ -427,6 +519,38 @@ def bind(self, database): session_template=Session(creator_role=self.database_role), ) + span_event_attributes = {"kind": type(self).__name__} + current_span = get_current_span() + requested_session_count = request.session_count + if requested_session_count <= 0: + add_span_event( + current_span, + f"Invalid session pool size({requested_session_count}) <= 0", + span_event_attributes, + ) + return + + add_span_event( + current_span, + f"Requesting {requested_session_count} sessions", + span_event_attributes, + ) + + if created_session_count >= self.size: + add_span_event( + current_span, + "Created no new sessions as sessionPool is full", + span_event_attributes, + ) + return + + add_span_event( + current_span, + f"Creating {request.session_count} sessions", + span_event_attributes, + ) + + returned_session_count = 0 while created_session_count < self.size: resp = api.batch_create_sessions( request=request, @@ -436,8 +560,16 @@ def bind(self, database): session = self._new_session() session._session_id = session_pb.name.split("/")[-1] self.put(session) + returned_session_count += 1 + created_session_count += len(resp.session) + add_span_event( + current_span, + f"Requested for {requested_session_count} sessions, return {returned_session_count}", + span_event_attributes, + ) + def get(self, timeout=None): """Check a session out from the pool. @@ -452,7 +584,26 @@ def get(self, timeout=None): if timeout is None: timeout = self.default_timeout - ping_after, session = self._sessions.get(block=True, timeout=timeout) + start_time = time.time() + span_event_attributes = {"kind": type(self).__name__} + current_span = get_current_span() + add_span_event( + current_span, + "Waiting for a session to become available", + span_event_attributes, + ) + + ping_after = None + session = None + try: + ping_after, session = self._sessions.get(block=True, timeout=timeout) + except queue.Empty as e: + add_span_event( + current_span, + "No sessions available in the pool within the specified timeout", + span_event_attributes, + ) + raise e if _NOW() > ping_after: # Using session.exists() guarantees the returned session exists. @@ -462,6 +613,14 @@ def get(self, timeout=None): session = self._new_session() session.create() + span_event_attributes.update( + { + "time.elapsed": time.time() - start_time, + "session.id": session._session_id, + "kind": "pinging_pool", + } + ) + add_span_event(current_span, "Acquired session", span_event_attributes) return session def put(self, session): diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 539f36af2b..166d5488c6 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -31,7 +31,11 @@ _metadata_with_prefix, _metadata_with_leader_aware_routing, ) -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import ( + add_span_event, + get_current_span, + trace_call, +) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.snapshot import Snapshot from google.cloud.spanner_v1.transaction import Transaction @@ -134,6 +138,9 @@ def create(self): :raises ValueError: if :attr:`session_id` is already set. """ + current_span = get_current_span() + add_span_event(current_span, "Creating Session") + if self._session_id is not None: raise ValueError("Session ID already set by back-end") api = self._database.spanner_api @@ -174,8 +181,18 @@ def exists(self): :rtype: bool :returns: True if the session exists on the back-end, else False. """ + current_span = get_current_span() if self._session_id is None: + add_span_event( + current_span, + "Checking session existence: Session does not exist as it has not been created yet", + ) return False + + add_span_event( + current_span, "Checking if Session exists", {"session.id": self._session_id} + ) + api = self._database.spanner_api metadata = _metadata_with_prefix(self._database.name) if self._database._route_to_leader_enabled: @@ -209,8 +226,17 @@ def delete(self): :raises ValueError: if :attr:`session_id` is not already set. :raises NotFound: if the session does not exist """ + current_span = get_current_span() if self._session_id is None: + add_span_event( + current_span, "Deleting Session failed due to unset session_id" + ) raise ValueError("Session ID not set by back-end") + + add_span_event( + current_span, "Deleting Session", {"session.id": self._session_id} + ) + api = self._database.spanner_api metadata = _metadata_with_prefix(self._database.name) observability_options = getattr(self._database, "observability_options", None) diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index d99c4fde2f..fa8e5121ff 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -32,7 +32,7 @@ from google.cloud.spanner_v1 import TransactionOptions from google.cloud.spanner_v1.snapshot import _SnapshotBase from google.cloud.spanner_v1.batch import _BatchBase -from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from google.cloud.spanner_v1._opentelemetry_tracing import add_span_event, trace_call from google.cloud.spanner_v1 import RequestOptions from google.api_core import gapic_v1 from google.api_core.exceptions import InternalServerError @@ -160,16 +160,25 @@ def begin(self): "CloudSpanner.BeginTransaction", self._session, observability_options=observability_options, - ): + ) as span: method = functools.partial( api.begin_transaction, session=self._session.name, options=txn_options, metadata=metadata, ) + + def beforeNextRetry(nthRetry, delayInSeconds): + add_span_event( + span, + "Transaction Begin Attempt Failed. Retrying", + {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + ) + response = _retry( method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, + beforeNextRetry=beforeNextRetry, ) self._transaction_id = response.id return self._transaction_id @@ -246,7 +255,6 @@ def commit( metadata.append( _metadata_with_leader_aware_routing(database._route_to_leader_enabled) ) - trace_attributes = {"num_mutations": len(self._mutations)} if request_options is None: request_options = RequestOptions() @@ -266,22 +274,38 @@ def commit( max_commit_delay=max_commit_delay, request_options=request_options, ) + + trace_attributes = {"num_mutations": len(self._mutations)} observability_options = getattr(database, "observability_options", None) with trace_call( "CloudSpanner.Commit", self._session, trace_attributes, observability_options, - ): + ) as span: + add_span_event(span, "Starting Commit") + method = functools.partial( api.commit, request=request, metadata=metadata, ) + + def beforeNextRetry(nthRetry, delayInSeconds): + add_span_event( + span, + "Transaction Commit Attempt Failed. Retrying", + {"attempt": nthRetry, "sleep_seconds": delayInSeconds}, + ) + response = _retry( method, allowed_exceptions={InternalServerError: _check_rst_stream_error}, + beforeNextRetry=beforeNextRetry, ) + + add_span_event(span, "Commit Done") + self.committed = response.commit_timestamp if return_commit_stats: self.commit_stats = response.commit_stats diff --git a/tests/_helpers.py b/tests/_helpers.py index 5e514f2586..81787c5a86 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -16,10 +16,11 @@ OTEL_SCOPE_NAME, OTEL_SCOPE_VERSION, ) + from opentelemetry.sdk.trace.sampling import TraceIdRatioBased from opentelemetry.trace.status import StatusCode - trace.set_tracer_provider(TracerProvider()) + trace.set_tracer_provider(TracerProvider(sampler=TraceIdRatioBased(1.0))) HAS_OPENTELEMETRY_INSTALLED = True except ImportError: @@ -86,9 +87,43 @@ def assertSpanAttributes( if HAS_OPENTELEMETRY_INSTALLED: if not span: span_list = self.ot_exporter.get_finished_spans() - self.assertEqual(len(span_list), 1) + self.assertEqual(len(span_list) > 0, True) span = span_list[0] self.assertEqual(span.name, name) self.assertEqual(span.status.status_code, status) self.assertEqual(dict(span.attributes), attributes) + + def assertSpanEvents(self, name, wantEventNames=[], span=None): + if not HAS_OPENTELEMETRY_INSTALLED: + return + + if not span: + span_list = self.ot_exporter.get_finished_spans() + self.assertEqual(len(span_list) > 0, True) + span = span_list[0] + + self.assertEqual(span.name, name) + actualEventNames = [] + for event in span.events: + actualEventNames.append(event.name) + self.assertEqual(actualEventNames, wantEventNames) + + def assertSpanNames(self, want_span_names): + if not HAS_OPENTELEMETRY_INSTALLED: + return + + span_list = self.get_finished_spans() + got_span_names = [span.name for span in span_list] + self.assertEqual(got_span_names, want_span_names) + + def get_finished_spans(self): + if HAS_OPENTELEMETRY_INSTALLED: + return list( + filter( + lambda span: span and span.name, + self.ot_exporter.get_finished_spans(), + ) + ) + else: + return [] diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 2f6b5e4ae9..a7f7a6f970 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -611,6 +611,10 @@ def __init__(self, database=None, name=TestBatch.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _Database(object): name = "testing" diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 90fa0c269f..6e29255fb7 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -3188,6 +3188,10 @@ def run_in_transaction(self, func, *args, **kw): self._retried = (func, args, kw) return self._committed + @property + def session_id(self): + return self.name + class _MockIterator(object): def __init__(self, *values, **kw): diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index 2e3b46fa73..fbb35201eb 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -14,10 +14,17 @@ from functools import total_ordering +import time import unittest from datetime import datetime, timedelta import mock +from google.cloud.spanner_v1._opentelemetry_tracing import trace_call +from tests._helpers import ( + OpenTelemetryBase, + StatusCode, + enrich_with_otel_scope, +) def _make_database(name="name"): @@ -133,7 +140,15 @@ def test_session_w_kwargs(self): self.assertEqual(checkout._kwargs, {"foo": "bar"}) -class TestFixedSizePool(unittest.TestCase): +class TestFixedSizePool(OpenTelemetryBase): + BASE_ATTRIBUTES = { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": "name", + "net.host.name": "spanner.googleapis.com", + } + enrich_with_otel_scope(BASE_ATTRIBUTES) + def _getTargetClass(self): from google.cloud.spanner_v1.pool import FixedSizePool @@ -216,6 +231,93 @@ def test_get_non_expired(self): self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) + def test_spans_bind_get(self): + # This tests retrieving 1 out of 4 sessions from the session pool. + pool = self._make_one(size=4) + database = _Database("name") + SESSIONS = sorted([_Session(database) for i in range(0, 4)]) + database._sessions.extend(SESSIONS) + pool.bind(database) + + with trace_call("pool.Get", SESSIONS[0]) as span: + pool.get() + wantEventNames = [ + "Acquiring session", + "Waiting for a session to become available", + "Acquired session", + ] + self.assertSpanEvents("pool.Get", wantEventNames, span) + + # Check for the overall spans too. + self.assertSpanAttributes( + "pool.Get", + attributes=TestFixedSizePool.BASE_ATTRIBUTES, + ) + + wantEventNames = [ + "Acquiring session", + "Waiting for a session to become available", + "Acquired session", + ] + self.assertSpanEvents("pool.Get", wantEventNames) + + def test_spans_bind_get_empty_pool(self): + # Tests trying to invoke pool.get() from an empty pool. + pool = self._make_one(size=0) + database = _Database("name") + session1 = _Session(database) + with trace_call("pool.Get", session1): + try: + pool.bind(database) + database._sessions = database._sessions[:0] + pool.get() + except Exception: + pass + + wantEventNames = [ + "Invalid session pool size(0) <= 0", + "Acquiring session", + "Waiting for a session to become available", + "No sessions available in the pool", + ] + self.assertSpanEvents("pool.Get", wantEventNames) + + # Check for the overall spans too. + self.assertSpanNames(["pool.Get"]) + self.assertSpanAttributes( + "pool.Get", + attributes=TestFixedSizePool.BASE_ATTRIBUTES, + ) + + def test_spans_pool_bind(self): + # Tests the exception generated from invoking pool.bind when + # you have an empty pool. + pool = self._make_one(size=1) + database = _Database("name") + SESSIONS = [] + database._sessions.extend(SESSIONS) + fauxSession = mock.Mock() + setattr(fauxSession, "_database", database) + try: + with trace_call("testBind", fauxSession): + pool.bind(database) + except Exception: + pass + + wantEventNames = [ + "Requesting 1 sessions", + "Creating 1 sessions", + "exception", + ] + self.assertSpanEvents("testBind", wantEventNames) + + # Check for the overall spans. + self.assertSpanAttributes( + "testBind", + status=StatusCode.ERROR, + attributes=TestFixedSizePool.BASE_ATTRIBUTES, + ) + def test_get_expired(self): pool = self._make_one(size=4) database = _Database("name") @@ -299,7 +401,15 @@ def test_clear(self): self.assertTrue(session._deleted) -class TestBurstyPool(unittest.TestCase): +class TestBurstyPool(OpenTelemetryBase): + BASE_ATTRIBUTES = { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": "name", + "net.host.name": "spanner.googleapis.com", + } + enrich_with_otel_scope(BASE_ATTRIBUTES) + def _getTargetClass(self): from google.cloud.spanner_v1.pool import BurstyPool @@ -347,6 +457,34 @@ def test_get_empty(self): session.create.assert_called() self.assertTrue(pool._sessions.empty()) + def test_spans_get_empty_pool(self): + # This scenario tests a pool that hasn't been filled up + # and pool.get() acquires from a pool, waiting for a session + # to become available. + pool = self._make_one() + database = _Database("name") + session1 = _Session(database) + database._sessions.append(session1) + pool.bind(database) + + with trace_call("pool.Get", session1): + session = pool.get() + self.assertIsInstance(session, _Session) + self.assertIs(session._database, database) + session.create.assert_called() + self.assertTrue(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.Get", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + wantEventNames = [ + "Acquiring session", + "Waiting for a session to become available", + "No sessions available in pool. Creating session", + ] + self.assertSpanEvents("pool.Get", wantEventNames) + def test_get_non_empty_session_exists(self): pool = self._make_one() database = _Database("name") @@ -361,6 +499,30 @@ def test_get_non_empty_session_exists(self): self.assertTrue(session._exists_checked) self.assertTrue(pool._sessions.empty()) + def test_spans_get_non_empty_session_exists(self): + # Tests the spans produces when you invoke pool.bind + # and then insert a session into the pool. + pool = self._make_one() + database = _Database("name") + previous = _Session(database) + pool.bind(database) + with trace_call("pool.Get", previous): + pool.put(previous) + session = pool.get() + self.assertIs(session, previous) + session.create.assert_not_called() + self.assertTrue(session._exists_checked) + self.assertTrue(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.Get", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + self.assertSpanEvents( + "pool.Get", + ["Acquiring session", "Waiting for a session to become available"], + ) + def test_get_non_empty_session_expired(self): pool = self._make_one() database = _Database("name") @@ -388,6 +550,22 @@ def test_put_empty(self): self.assertFalse(pool._sessions.empty()) + def test_spans_put_empty(self): + # Tests the spans produced when you put sessions into an empty pool. + pool = self._make_one() + database = _Database("name") + pool.bind(database) + session = _Session(database) + + with trace_call("pool.put", session): + pool.put(session) + self.assertFalse(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.put", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + def test_put_full(self): pool = self._make_one(target_size=1) database = _Database("name") @@ -402,6 +580,28 @@ def test_put_full(self): self.assertTrue(younger._deleted) self.assertIs(pool.get(), older) + def test_spans_put_full(self): + # This scenario tests the spans produced from putting an older + # session into a pool that is already full. + pool = self._make_one(target_size=1) + database = _Database("name") + pool.bind(database) + older = _Session(database) + with trace_call("pool.put", older): + pool.put(older) + self.assertFalse(pool._sessions.empty()) + + younger = _Session(database) + pool.put(younger) # discarded silently + + self.assertTrue(younger._deleted) + self.assertIs(pool.get(), older) + + self.assertSpanAttributes( + "pool.put", + attributes=TestBurstyPool.BASE_ATTRIBUTES, + ) + def test_put_full_expired(self): pool = self._make_one(target_size=1) database = _Database("name") @@ -426,9 +626,18 @@ def test_clear(self): pool.clear() self.assertTrue(previous._deleted) + self.assertNoSpans() + +class TestPingingPool(OpenTelemetryBase): + BASE_ATTRIBUTES = { + "db.type": "spanner", + "db.url": "spanner.googleapis.com", + "db.instance": "name", + "net.host.name": "spanner.googleapis.com", + } + enrich_with_otel_scope(BASE_ATTRIBUTES) -class TestPingingPool(unittest.TestCase): def _getTargetClass(self): from google.cloud.spanner_v1.pool import PingingPool @@ -505,6 +714,7 @@ def test_get_hit_no_ping(self): self.assertIs(session, SESSIONS[0]) self.assertFalse(session._exists_checked) self.assertFalse(pool._sessions.full()) + self.assertNoSpans() def test_get_hit_w_ping(self): import datetime @@ -526,6 +736,7 @@ def test_get_hit_w_ping(self): self.assertIs(session, SESSIONS[0]) self.assertTrue(session._exists_checked) self.assertFalse(pool._sessions.full()) + self.assertNoSpans() def test_get_hit_w_ping_expired(self): import datetime @@ -549,6 +760,7 @@ def test_get_hit_w_ping_expired(self): session.create.assert_called() self.assertTrue(SESSIONS[0]._exists_checked) self.assertFalse(pool._sessions.full()) + self.assertNoSpans() def test_get_empty_default_timeout(self): import queue @@ -560,6 +772,7 @@ def test_get_empty_default_timeout(self): pool.get() self.assertEqual(session_queue._got, {"block": True, "timeout": 10}) + self.assertNoSpans() def test_get_empty_explicit_timeout(self): import queue @@ -571,6 +784,7 @@ def test_get_empty_explicit_timeout(self): pool.get(timeout=1) self.assertEqual(session_queue._got, {"block": True, "timeout": 1}) + self.assertNoSpans() def test_put_full(self): import queue @@ -585,6 +799,7 @@ def test_put_full(self): pool.put(_Session(database)) self.assertTrue(pool._sessions.full()) + self.assertNoSpans() def test_put_non_full(self): import datetime @@ -605,6 +820,7 @@ def test_put_non_full(self): ping_after, queued = session_queue._items[0] self.assertEqual(ping_after, now + datetime.timedelta(seconds=3000)) self.assertIs(queued, session) + self.assertNoSpans() def test_clear(self): pool = self._make_one() @@ -623,10 +839,12 @@ def test_clear(self): for session in SESSIONS: self.assertTrue(session._deleted) + self.assertNoSpans() def test_ping_empty(self): pool = self._make_one(size=1) pool.ping() # Does not raise 'Empty' + self.assertNoSpans() def test_ping_oldest_fresh(self): pool = self._make_one(size=1) @@ -638,6 +856,7 @@ def test_ping_oldest_fresh(self): pool.ping() self.assertFalse(SESSIONS[0]._pinged) + self.assertNoSpans() def test_ping_oldest_stale_but_exists(self): import datetime @@ -674,193 +893,36 @@ def test_ping_oldest_stale_and_not_exists(self): self.assertTrue(SESSIONS[0]._pinged) SESSIONS[1].create.assert_called() + self.assertNoSpans() - -class TestTransactionPingingPool(unittest.TestCase): - def _getTargetClass(self): - from google.cloud.spanner_v1.pool import TransactionPingingPool - - return TransactionPingingPool - - def _make_one(self, *args, **kwargs): - return self._getTargetClass()(*args, **kwargs) - - def test_ctor_defaults(self): - pool = self._make_one() - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.empty()) - self.assertTrue(pool._pending_sessions.empty()) - self.assertEqual(pool.labels, {}) - self.assertIsNone(pool.database_role) - - def test_ctor_explicit(self): - labels = {"foo": "bar"} - database_role = "dummy-role" - pool = self._make_one( - size=4, - default_timeout=30, - ping_interval=1800, - labels=labels, - database_role=database_role, - ) - self.assertIsNone(pool._database) - self.assertEqual(pool.size, 4) - self.assertEqual(pool.default_timeout, 30) - self.assertEqual(pool._delta.seconds, 1800) - self.assertTrue(pool._sessions.empty()) - self.assertTrue(pool._pending_sessions.empty()) - self.assertEqual(pool.labels, labels) - self.assertEqual(pool.database_role, database_role) - - def test_ctor_explicit_w_database_role_in_db(self): - database_role = "dummy-role" - pool = self._make_one() - database = pool._database = _Database("name") - SESSIONS = [_Session(database)] * 10 - database._sessions.extend(SESSIONS) - database._database_role = database_role - pool.bind(database) - self.assertEqual(pool.database_role, database_role) - - def test_bind(self): + def test_spans_get_and_leave_empty_pool(self): + # This scenario tests the spans generated from pulling a span + # out the pool and leaving it empty. pool = self._make_one() database = _Database("name") - SESSIONS = [_Session(database) for _ in range(10)] - database._sessions.extend(SESSIONS) - pool.bind(database) - - self.assertIs(pool._database, database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - txn = session._transaction - txn.begin.assert_not_called() - - self.assertTrue(pool._pending_sessions.empty()) - - def test_bind_w_timestamp_race(self): - import datetime - from google.cloud._testing import _Monkey - from google.cloud.spanner_v1 import pool as MUT - - NOW = datetime.datetime.utcnow() - pool = self._make_one() - database = _Database("name") - SESSIONS = [_Session(database) for _ in range(10)] - database._sessions.extend(SESSIONS) - - with _Monkey(MUT, _NOW=lambda: NOW): + session1 = _Session(database) + database._sessions.append(session1) + try: pool.bind(database) + except Exception: + pass - self.assertIs(pool._database, database) - self.assertEqual(pool.size, 10) - self.assertEqual(pool.default_timeout, 10) - self.assertEqual(pool._delta.seconds, 3000) - self.assertTrue(pool._sessions.full()) - - api = database.spanner_api - self.assertEqual(api.batch_create_sessions.call_count, 5) - for session in SESSIONS: - session.create.assert_not_called() - txn = session._transaction - txn.begin.assert_not_called() - - self.assertTrue(pool._pending_sessions.empty()) - - def test_put_full(self): - import queue - - pool = self._make_one(size=4) - database = _Database("name") - SESSIONS = [_Session(database) for _ in range(4)] - database._sessions.extend(SESSIONS) - pool.bind(database) - - with self.assertRaises(queue.Full): - pool.put(_Session(database)) - - self.assertTrue(pool._sessions.full()) - - def test_put_non_full_w_active_txn(self): - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - pending = pool._pending_sessions = _Queue() - database = _Database("name") - session = _Session(database) - txn = session.transaction() - - pool.put(session) - - self.assertEqual(len(session_queue._items), 1) - _, queued = session_queue._items[0] - self.assertIs(queued, session) - - self.assertEqual(len(pending._items), 0) - txn.begin.assert_not_called() - - def test_put_non_full_w_committed_txn(self): - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - pending = pool._pending_sessions = _Queue() - database = _Database("name") - session = _Session(database) - committed = session.transaction() - committed.committed = True - - pool.put(session) - - self.assertEqual(len(session_queue._items), 0) - - self.assertEqual(len(pending._items), 1) - self.assertIs(pending._items[0], session) - self.assertIsNot(session._transaction, committed) - session._transaction.begin.assert_not_called() - - def test_put_non_full(self): - pool = self._make_one(size=1) - session_queue = pool._sessions = _Queue() - pending = pool._pending_sessions = _Queue() - database = _Database("name") - session = _Session(database) - - pool.put(session) - - self.assertEqual(len(session_queue._items), 0) - self.assertEqual(len(pending._items), 1) - self.assertIs(pending._items[0], session) - - self.assertFalse(pending.empty()) - - def test_begin_pending_transactions_empty(self): - pool = self._make_one(size=1) - pool.begin_pending_transactions() # no raise - - def test_begin_pending_transactions_non_empty(self): - pool = self._make_one(size=1) - pool._sessions = _Queue() - - database = _Database("name") - TRANSACTIONS = [_make_transaction(object())] - PENDING_SESSIONS = [_Session(database, transaction=txn) for txn in TRANSACTIONS] - - pending = pool._pending_sessions = _Queue(*PENDING_SESSIONS) - self.assertFalse(pending.empty()) - - pool.begin_pending_transactions() # no raise - - for txn in TRANSACTIONS: - txn.begin.assert_not_called() - - self.assertTrue(pending.empty()) + with trace_call("pool.Get", session1): + session = pool.get() + self.assertIsInstance(session, _Session) + self.assertIs(session._database, database) + # session.create.assert_called() + self.assertTrue(pool._sessions.empty()) + + self.assertSpanAttributes( + "pool.Get", + attributes=TestPingingPool.BASE_ATTRIBUTES, + ) + wantEventNames = [ + "Waiting for a session to become available", + "Acquired session", + ] + self.assertSpanEvents("pool.Get", wantEventNames) class TestSessionCheckout(unittest.TestCase): @@ -945,6 +1007,8 @@ def __init__( self._deleted = False self._transaction = transaction self._last_use_time = last_use_time + # Generate a faux id. + self._session_id = f"{time.time()}" def __lt__(self, other): return id(self) < id(other) @@ -975,6 +1039,10 @@ def transaction(self): txn = self._transaction = _make_transaction(self) return txn + @property + def session_id(self): + return self._session_id + class _Database(object): def __init__(self, name): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 2ae0cb94b8..966adadcbd 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -15,6 +15,7 @@ import google.api_core.gapic_v1.method from google.cloud.spanner_v1 import RequestOptions +from google.cloud.spanner_v1._opentelemetry_tracing import trace_call import mock from tests._helpers import ( OpenTelemetryBase, @@ -174,6 +175,43 @@ def test_create_w_database_role(self): "CloudSpanner.CreateSession", attributes=TestSession.BASE_ATTRIBUTES ) + def test_create_session_span_annotations(self): + from google.cloud.spanner_v1 import CreateSessionRequest + from google.cloud.spanner_v1 import Session as SessionRequestProto + + session_pb = self._make_session_pb( + self.SESSION_NAME, database_role=self.DATABASE_ROLE + ) + + gax_api = self._make_spanner_api() + gax_api.create_session.return_value = session_pb + database = self._make_database(database_role=self.DATABASE_ROLE) + database.spanner_api = gax_api + session = self._make_one(database, database_role=self.DATABASE_ROLE) + + with trace_call("TestSessionSpan", session) as span: + session.create() + + self.assertEqual(session.session_id, self.SESSION_ID) + self.assertEqual(session.database_role, self.DATABASE_ROLE) + session_template = SessionRequestProto(creator_role=self.DATABASE_ROLE) + + request = CreateSessionRequest( + database=database.name, + session=session_template, + ) + + gax_api.create_session.assert_called_once_with( + request=request, + metadata=[ + ("google-cloud-resource-prefix", database.name), + ("x-goog-spanner-route-to-leader", "true"), + ], + ) + + wantEventNames = ["Creating Session"] + self.assertSpanEvents("TestSessionSpan", wantEventNames, span) + def test_create_wo_database_role(self): from google.cloud.spanner_v1 import CreateSessionRequest diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index bf7363fef2..479a0d62e9 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -1822,6 +1822,10 @@ def __init__(self, database=None, name=TestSnapshot.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _MockIterator(object): def __init__(self, *values, **kw): diff --git a/tests/unit/test_spanner.py b/tests/unit/test_spanner.py index ab5479eb3c..ff34a109af 100644 --- a/tests/unit/test_spanner.py +++ b/tests/unit/test_spanner.py @@ -1082,6 +1082,10 @@ def __init__(self, database=None, name=TestTransaction.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _MockIterator(object): def __init__(self, *values, **kw): diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d52fb61db1..e426f912b2 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -939,6 +939,10 @@ def __init__(self, database=None, name=TestTransaction.SESSION_NAME): self._database = database self.name = name + @property + def session_id(self): + return self.name + class _FauxSpannerAPI(object): _committed = None