diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 79a501ddc3..46379194ee 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -54,8 +54,8 @@ def get_tracer(tracer_provider=None): return tracer_provider.get_tracer(TRACER_NAME, TRACER_VERSION) -def _prepare_span_attributes( - name, session=None, extra_attributes=None, observability_options=None +def _make_tracer_and_span_attributes( + session=None, extra_attributes=None, observability_options=None ): if not HAS_OPENTELEMETRY_INSTALLED: return None, None @@ -104,8 +104,13 @@ def _prepare_span_attributes( @contextmanager def trace_call(name, session=None, extra_attributes=None, observability_options=None): - tracer, span_attributes = _prepare_span_attributes( - name, session, extra_attributes, observability_options + """ + trace_call is used in situations where you need to end a span with a context manager + or after a scope is exited. If you need to keep a span alive and lazily end it, please + invoke `trace_call_end_lazily`. + """ + tracer, span_attributes = _make_tracer_and_span_attributes( + session, extra_attributes, observability_options ) if not tracer: yield None @@ -124,10 +129,17 @@ def trace_call(name, session=None, extra_attributes=None, observability_options= span.set_status(Status(StatusCode.OK)) -def trace_end_explicitly( +def trace_call_end_lazily( name, session=None, extra_attributes=None, observability_options=None ): - tracer, span_attributes = _prepare_span_attributes( + """ + trace_call_end_lazily is used in situations where you won't have a context manager + and need to end a span explicitly when a specific condition happens. If you need a + context manager, please invoke `trace_call` with which you can invoke + `with trace_call(...) as span:` + It is the caller's responsibility to explicitly invoke span.end() + """ + tracer, span_attributes = _make_tracer_and_span_attributes( session, extra_attributes, observability_options ) if not tracer: @@ -135,3 +147,9 @@ def trace_end_explicitly( return tracer.start_span( name, kind=trace.SpanKind.CLIENT, attributes=span_attributes ) + + +def get_current_span(): + if not HAS_OPENTELEMETRY_INSTALLED: + return None + return trace.get_current_span() diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 1f177dd0c0..2e96e49373 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -27,8 +27,9 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1._opentelemetry_tracing import ( + get_current_span, trace_call, - trace_end_explicitly, + trace_call_end_lazily, ) from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1._helpers import _retry @@ -52,7 +53,7 @@ def __init__(self, session): observability_options = getattr( self._session._database, "observability_options", None ) - self.__span = trace_end_explicitly( + self.__span = trace_call_end_lazily( "CloudSpanner." + type(self).__name__, self._session, observability_options=observability_options, @@ -102,11 +103,11 @@ def update(self, table, columns, values): "update mutations inserted", dict(table=table, columns=columns) ) - def add_event_on_current_span(self, event_commentary, attributes=None): + def _add_event_on_current_span(self, event_commentary, attributes=None): current_span = get_current_span() if not current_span: return - span.add_event(event_commentary, attributes) + current_span.add_event(event_commentary, attributes) def insert_or_update(self, table, columns, values): """Insert/update one or more table rows. @@ -155,9 +156,7 @@ def delete(self, table, keyset): """ delete = Mutation.Delete(table=table, key_set=keyset._to_pb()) self._mutations.append(Mutation(delete=delete)) - self._add_event_on_current_span( - "delete mutations inserted", dict(table=table, columns=columns) - ) + self._add_event_on_current_span("delete mutations inserted", dict(table=table)) class Batch(_BatchBase): @@ -263,7 +262,7 @@ def __enter__(self): observability_options = getattr( self._session._database, "observability_options", None ) - self.__span = trace_end_explicitly( + self.__span = trace_call_end_lazily( "CloudSpanner.Batch", self._session, observability_options=observability_options, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 737b6464fd..0691ca0c2f 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -55,7 +55,7 @@ ) from google.cloud.spanner_v1._opentelemetry_tracing import ( trace_call, - trace_end_explicitly, + trace_call_end_lazily, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -1195,7 +1195,7 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" observability_options = self._database.observability_options - self.__span = trace_end_explicitly( + self.__span = trace_call_end_lazily( "CloudSpanner.Database.batch", None, observability_options=observability_options, @@ -1291,7 +1291,7 @@ def __enter__(self): attributes = dict() if self._kw: attributes["multi_use"] = self._kw["multi_use"] - self.__span = trace_end_explicitly( + self.__span = trace_call_end_lazily( "CloudSpanner.Database.snapshot", None, attributes, @@ -1346,7 +1346,7 @@ def __init__( self._exact_staleness = exact_staleness observability_options = getattr(self._database, "observability_options", {}) self.__observability_options = observability_options - self.__span = trace_end_explicitly( + self.__span = trace_call_end_lazily( "CloudSpanner.BatchSnapshot", self._session, observability_options=observability_options, diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 3ea4f42577..27d08664b3 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -616,7 +616,7 @@ def test_read_other_error(self): list(derived.read(TABLE_NAME, COLUMNS, keyset)) self.assertSpanAttributes( - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner._Derived.read", status=StatusCode.ERROR, attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) @@ -773,7 +773,7 @@ def _read_helper( ) self.assertSpanAttributes( - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner._Derived.read", attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) ), @@ -1195,10 +1195,13 @@ def _partition_read_helper( ) self.assertSpanAttributes( - "CloudSpanner.PartitionReadOnlyTransaction", + "CloudSpanner.partition_read", status=StatusCode.OK, attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) + BASE_ATTRIBUTES, + table_id=TABLE_NAME, + columns=tuple(COLUMNS), + index="0", ), ) @@ -1226,7 +1229,7 @@ def test_partition_read_other_error(self): list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadOnlyTransaction", + "CloudSpanner.partition_read", status=StatusCode.ERROR, attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) @@ -1369,7 +1372,7 @@ def _partition_query_helper( ) self.assertSpanAttributes( - "CloudSpanner.PartitionReadWriteTransaction", + "CloudSpanner.partition_query", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) @@ -1755,7 +1758,7 @@ def test_begin_ok_exact_staleness(self): ) self.assertSpanAttributes( - "CloudSpanner.BeginTransaction", + "CloudSpanner.begin", status=StatusCode.OK, attributes=BASE_ATTRIBUTES, ) @@ -1791,7 +1794,7 @@ def test_begin_ok_exact_strong(self): ) self.assertSpanAttributes( - "CloudSpanner.BeginTransaction", + "CloudSpanner.begin", status=StatusCode.OK, attributes=BASE_ATTRIBUTES, )