diff --git a/google/cloud/spanner_v1/_opentelemetry_tracing.py b/google/cloud/spanner_v1/_opentelemetry_tracing.py index 08a5aa7016..8f7943a2fa 100644 --- a/google/cloud/spanner_v1/_opentelemetry_tracing.py +++ b/google/cloud/spanner_v1/_opentelemetry_tracing.py @@ -54,11 +54,11 @@ def get_tracer(tracer_provider=None): return tracer_provider.get_tracer(TRACER_NAME, TRACER_VERSION) -def trace_end_explicitly( - name, session=None, extra_attributes=None, observability_options=None +def _make_tracer_and_span_attributes( + session=None, extra_attributes=None, observability_options=None ): if not HAS_OPENTELEMETRY_INSTALLED: - return None + return None, None tracer_provider = None @@ -68,7 +68,7 @@ def trace_end_explicitly( enable_extended_tracing = True db_name = "" - if session and session._database: + if session and getattr(session, "_database", None): db_name = session._database.name if isinstance(observability_options, dict): # Avoid false positives with mock.Mock @@ -99,13 +99,45 @@ def trace_end_explicitly( if not enable_extended_tracing: attributes.pop("db.statement", False) - return tracer.start_span(name, kind=trace.SpanKind.CLIENT, attributes=attributes) + return tracer, attributes + + +def trace_call_end_lazily( + name, session=None, extra_attributes=None, observability_options=None +): + """ +  trace_call_end_lazily is used in situations where you won't have a context manager +  and need to end a span explicitly when a specific condition happens. If you need a +  context manager, please invoke `trace_call` with which you can invoke +  `with trace_call(...) 
as span:` +  It is the caller's responsibility to explicitly invoke span.end() + """ + tracer, span_attributes = _make_tracer_and_span_attributes( + session, extra_attributes, observability_options + ) + if not tracer: + return None + return tracer.start_span( + name, kind=trace.SpanKind.CLIENT, attributes=span_attributes + ) @contextmanager def trace_call(name, session=None, extra_attributes=None, observability_options=None): - with trace_end_explicitly( - name, session, extra_attributes, observability_options + """ +  trace_call is used in situations where you need to end a span with a context manager +  or after a scope is exited. If you need to keep a span alive and lazily end it, please +  invoke `trace_call_end_lazily`. + """ + tracer, span_attributes = _make_tracer_and_span_attributes( + session, extra_attributes, observability_options + ) + if not tracer: + yield None + return + + with tracer.start_as_current_span( + name, kind=trace.SpanKind.CLIENT, attributes=span_attributes ) as span: try: yield span @@ -115,3 +147,25 @@ def trace_call(name, session=None, extra_attributes=None, observability_options= raise else: span.set_status(Status(StatusCode.OK)) + + +def set_span_status_error(span, error): + if span: + span.set_status(Status(StatusCode.ERROR, str(error))) + + +def set_span_status_ok(span): + if span: + span.set_status(Status(StatusCode.OK)) + + +def get_current_span(): + if not HAS_OPENTELEMETRY_INSTALLED: + return None + return trace.get_current_span() + + +def add_event_on_span(span, event_name, attributes=None): + if span is None: + return + span.add_event(event_name, attributes) diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 39e10f3d3f..63bd7f281f 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -27,8 +27,9 @@ _metadata_with_leader_aware_routing, ) from google.cloud.spanner_v1._opentelemetry_tracing import ( + 
add_event_on_span, trace_call, - trace_end_explicitly, + trace_call_end_lazily, ) from google.cloud.spanner_v1 import RequestOptions from google.cloud.spanner_v1._helpers import _retry @@ -50,9 +51,9 @@ def __init__(self, session): super(_BatchBase, self).__init__(session) self._mutations = [] observability_options = getattr( - self._session.database, "observability_options", None + self._session._database, "observability_options", None ) - self.__span = trace_end_explicitly( - "CloudSpannerX." + type(self).__name__, + self.__span = trace_call_end_lazily( + "CloudSpanner." + type(self).__name__, self._session, observability_options=observability_options, @@ -80,10 +81,9 @@ def insert(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ - if self.__span: - self.__span.add_event( - "insert mutations inserted", dict(table=table, columns=columns) - ) + add_event_on_span( + self.__span, "insert mutations added", dict(table=table, columns=columns) + ) self._mutations.append(Mutation(insert=_make_write_pb(table, columns, values))) def update(self, table, columns, values): """Update one or more table rows. @@ -98,11 +98,10 @@ :type table: str :param table: Name of the table to be modified. :type columns: list of str :param columns: Name of the table columns to be modified. :type values: list of lists :param values: Values to be modified. """ - if self.__span: - self.__span.add_event( - "update mutations inserted", dict(table=table, columns=columns) - ) self._mutations.append(Mutation(update=_make_write_pb(table, columns, values))) + add_event_on_span( + self.__span, "update mutations added", dict(table=table, columns=columns) + ) def insert_or_update(self, table, columns, values): """Insert/update one or more table rows. @@ -116,14 +115,14 @@ :type table: str :param table: Name of the table to be modified. :type columns: list of str :param columns: Name of the table columns to be modified. :type values: list of lists :param values: Values to be modified. 
""" - if self.__span: - self.__span.add_event( - "insert_or_update mutations inserted", - dict(table=table, columns=columns), - ) self._mutations.append( Mutation(insert_or_update=_make_write_pb(table, columns, values)) ) + add_event_on_span( + self.__span, + "insert_or_update mutations added", + dict(table=table, columns=columns), + ) def replace(self, table, columns, values): """Replace one or more table rows. @@ -137,11 +136,10 @@ def replace(self, table, columns, values): :type values: list of lists :param values: Values to be modified. """ - if self.__span: - self.__span.add_event( - "replace mutations inserted", dict(table=table, columns=columns) - ) self._mutations.append(Mutation(replace=_make_write_pb(table, columns, values))) + add_event_on_span( + self.__span, "replace mutations added", dict(table=table, columns=columns) + ) def delete(self, table, keyset): """Delete one or more table rows. @@ -154,10 +152,7 @@ def delete(self, table, keyset): """ delete = Mutation.Delete(table=table, key_set=keyset._to_pb()) self._mutations.append(Mutation(delete=delete)) - if self.__span: - self.__span.add_event( - "delete mutations inserted", dict(table=table, columns=columns) - ) + add_event_on_span(self.__span, "delete mutations added", dict(table=table)) class Batch(_BatchBase): @@ -261,9 +256,9 @@ def __enter__(self): """Begin ``with`` block.""" self._check_state() observability_options = getattr( - self._session.database, "observability_options", None + self._session._database, "observability_options", None ) - self.__span = trace_end_explicitly( + self.__span = trace_call_end_lazily( "CloudSpanner.Batch", self._session, observability_options=observability_options, diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 23ed1b0ed0..549ede0045 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -55,7 +55,9 @@ ) from google.cloud.spanner_v1._opentelemetry_tracing import ( trace_call, - 
trace_end_explicitly, + trace_call_end_lazily, + set_span_status_error, + set_span_status_ok, ) from google.cloud.spanner_v1.batch import Batch from google.cloud.spanner_v1.batch import MutationGroups @@ -738,8 +740,6 @@ def do_execute_pdml(session, span): return result_set.stats.row_count_lower_bound observability_options = getattr(self, "observability_options", {}) - if isinstance(observability_options, dict): - observability_options["db_name"] = self.name with trace_call( "CloudSpanner.execute_partitioned_pdml", None, @@ -909,9 +909,7 @@ def run_in_transaction(self, func, *args, **kw): # Check out a session and run the function in a transaction; once # done, flip the sanity check bit back. try: - observability_options = getattr(self, "observability_options", {}) - if isinstance(observability_options, dict): - observability_options["db_name"] = self.name + observability_options = getattr(self, "observability_options", None) with trace_call( "CloudSpanner.Database.run_in_transaction", None, @@ -1146,7 +1144,12 @@ def observability_options(self): if not (self._instance and self._instance._client): return None - return getattr(self._instance._client, "observability_options", None) + opts = getattr(self._instance._client, "observability_options", None) + if not opts: + opts = dict() + + opts["db_name"] = self.name + return opts class BatchCheckout(object): @@ -1195,10 +1198,9 @@ def __init__( def __enter__(self): """Begin ``with`` block.""" - observability_options = self._database.observability_options - self.__span = trace_end_explicitly( + observability_options = getattr(self._database, "observability_options", None) + self.__span = trace_call_end_lazily( "CloudSpanner.Database.batch", - None, observability_options=observability_options, ) session = self._session = self._database._pool.get() @@ -1223,11 +1225,18 @@ def __exit__(self, exc_type, exc_val, exc_tb): "CommitStats: {}".format(self._batch.commit_stats), extra={"commit_stats": self._batch.commit_stats}, ) - 
self._database._pool.put(self._session) + if self.__span: + if not exc_type: + set_span_status_ok(self.__span) + else: + set_span_status_error(self.__span, exc_val) + self.__span.record_exception(exc_val) self.__span.end() self.__span = None + self._database._pool.put(self._session) + class MutationGroupsCheckout(object): """Context manager for using mutation groups from a database. @@ -1288,14 +1297,13 @@ def __init__(self, database, **kw): def __enter__(self): """Begin ``with`` block.""" - observability_options = self._database.observability_options - attributes = dict() + observability_options = getattr(self._database, "observability_options", {}) + attributes = None if self._kw: - attributes["multi_use"] = self._kw["multi_use"] - self.__span = trace_end_explicitly( + attributes = dict(multi_use=self._kw.get("multi_use", False)) + self.__span = trace_call_end_lazily( "CloudSpanner.Database.snapshot", - None, - attributes, + extra_attributes=attributes, observability_options=observability_options, ) session = self._session = self._database._pool.get() @@ -1309,12 +1317,18 @@ def __exit__(self, exc_type, exc_val, exc_tb): if not self._session.exists(): self._session = self._database._pool._new_session() self._session.create() - self._database._pool.put(self._session) if self.__span: + if not exc_type: + set_span_status_ok(self.__span) + else: + set_span_status_error(self.__span, exc_val) + self.__span.record_exception(exc_val) self.__span.end() self.__span = None + self._database._pool.put(self._session) + class BatchSnapshot(object): """Wrapper for generating and processing read / query batches. 
@@ -1346,10 +1360,8 @@ def __init__( self._read_timestamp = read_timestamp self._exact_staleness = exact_staleness observability_options = getattr(self._database, "observability_options", {}) - if isinstance(observability_options, dict) and self._database: - observability_options["db_name"] = self._database.name self.__observability_options = observability_options - self.__span = trace_end_explicitly( + self.__span = trace_call_end_lazily( "CloudSpanner.BatchSnapshot", self._session, observability_options=observability_options, @@ -1512,27 +1524,32 @@ def generate_read_batches( mappings of information used perform actual partitioned reads via :meth:`process_read_batch`. """ - partitions = self._get_snapshot().partition_read( - table=table, - columns=columns, - keyset=keyset, - index=index, - partition_size_bytes=partition_size_bytes, - max_partitions=max_partitions, - retry=retry, - timeout=timeout, - ) + with trace_call( + f"CloudSpanner.{type(self).__name__}.generate_read_partitions", + extra_attributes=dict(table=table, columns=columns), + observability_options=self.observability_options, + ): + partitions = self._get_snapshot().partition_read( + table=table, + columns=columns, + keyset=keyset, + index=index, + partition_size_bytes=partition_size_bytes, + max_partitions=max_partitions, + retry=retry, + timeout=timeout, + ) - read_info = { - "table": table, - "columns": columns, - "keyset": keyset._to_dict(), - "index": index, - "data_boost_enabled": data_boost_enabled, - "directed_read_options": directed_read_options, - } - for partition in partitions: - yield {"partition": partition, "read": read_info.copy()} + read_info = { + "table": table, + "columns": columns, + "keyset": keyset._to_dict(), + "index": index, + "data_boost_enabled": data_boost_enabled, + "directed_read_options": directed_read_options, + } + for partition in partitions: + yield {"partition": partition, "read": read_info.copy()} def process_read_batch( self, @@ -1755,18 +1772,23 @@ def 
run_partitioned_query( :rtype: :class:`~google.cloud.spanner_v1.merged_result_set.MergedResultSet` :returns: a result set instance which can be used to consume rows. """ - partitions = list( - self.generate_query_batches( - sql, - params, - param_types, - partition_size_bytes, - max_partitions, - query_options, - data_boost_enabled, + with trace_call( + f"CloudSpanner.{type(self).__name__}.run_partitioned_query", + extra_attributes=dict(sql=sql), + observability_options=self.observability_options, + ): + partitions = list( + self.generate_query_batches( + sql, + params, + param_types, + partition_size_bytes, + max_partitions, + query_options, + data_boost_enabled, + ) ) - ) - return MergedResultSet(self, partitions, 0) + return MergedResultSet(self, partitions, 0) def process(self, batch): """Process a single, partitioned query or read. diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index ec739865e5..e67db347ca 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -315,13 +315,14 @@ def read( trace_attributes = {"table_id": table, "columns": columns} observability_options = getattr(database, "observability_options", None) + span_name = f"CloudSpanner.{type(self).__name__}.read" if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.read", + span_name, self._session, trace_attributes, transaction=self, @@ -338,7 +339,7 @@ def read( iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.read", + span_name, self._session, trace_attributes, transaction=self, @@ -520,6 +521,7 @@ def execute_sql( ) else: return self._get_streamed_result_set( + span_name, restart, request, trace_attributes, @@ -538,7 +540,7 @@ def _get_streamed_result_set( iterator = _restart_on_unavailable( restart, request, - "CloudSpanner.execute_streaming_sql", + 
f"CloudSpanner.{type(self).__name__}.execute_streaming_sql", self._session, trace_attributes, transaction=self, @@ -630,9 +632,13 @@ def partition_read( partition_options=partition_options, ) - trace_attributes = {"table_id": table, "columns": columns, "index": index} + trace_attributes = {"table_id": table, "columns": columns} + can_include_index = (index != "") and (index is not None) + if can_include_index: + trace_attributes["index"] = index + with trace_call( - "CloudSpanner.partition_read", + f"CloudSpanner.{type(self).__name__}.partition_read", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), @@ -735,7 +741,7 @@ def partition_query( trace_attributes = {"db.statement": sql} with trace_call( - "CloudSpanner.partition_query", + f"CloudSpanner.{type(self).__name__}.partition_query", self._session, trace_attributes, observability_options=getattr(database, "observability_options", None), @@ -882,7 +888,7 @@ def begin(self): ) txn_selector = self._make_txn_selector() with trace_call( - "CloudSpanner.begin", + f"CloudSpanner.{type(self).__name__}.begin", self._session, observability_options=getattr(database, "observability_options", None), ): diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index 0a8f05193c..19f9e20a72 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -190,7 +190,7 @@ def rollback(self): ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.Rollback", + f"CloudSpanner.{type(self).__name__}.rollback", self._session, observability_options=observability_options, ): @@ -268,7 +268,7 @@ def commit( ) observability_options = getattr(database, "observability_options", None) with trace_call( - "CloudSpanner.Commit", + f"CloudSpanner.{type(self).__name__}.commit", self._session, trace_attributes, observability_options, diff --git a/tests/_helpers.py 
b/tests/_helpers.py index 5e514f2586..b822b950a2 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -91,4 +91,6 @@ def assertSpanAttributes( self.assertEqual(span.name, name) self.assertEqual(span.status.status_code, status) + print("got_span_attributes ", dict(span.attributes)) + print("want_span_attributes", attributes) self.assertEqual(dict(span.attributes), attributes) diff --git a/tests/system/test_observability_options.py b/tests/system/test_observability_options.py index aee79d6549..d1323ebb28 100644 --- a/tests/system/test_observability_options.py +++ b/tests/system/test_observability_options.py @@ -37,7 +37,7 @@ not HAS_OTEL_INSTALLED, reason="OpenTelemetry is necessary to test traces." ) @pytest.mark.skipif( - not _helpers.USE_EMULATOR, reason="mulator is necessary to test traces." + not _helpers.USE_EMULATOR, reason="emulator is necessary to test traces." ) def test_observability_options_propagation(): PROJECT = _helpers.EMULATOR_PROJECT @@ -105,16 +105,21 @@ def test_propagation(enable_extended_tracing): len(from_inject_spans) >= 2 ) # "Expecting at least 2 spans from the injected trace exporter" gotNames = [span.name for span in from_inject_spans] - wantNames = ["CloudSpanner.CreateSession", "CloudSpanner.execute_streaming_sql"] + wantNames = [ + "CloudSpanner.CreateSession", + "CloudSpanner.Snapshot.execute_streaming_sql", + "CloudSpanner.Database.snapshot", + ] assert gotNames == wantNames # Check for conformance of enable_extended_tracing - lastSpan = from_inject_spans[len(from_inject_spans) - 1] + snapshot_execute_span = from_inject_spans[len(from_inject_spans) - 2] wantAnnotatedSQL = "SELECT 1" if not enable_extended_tracing: wantAnnotatedSQL = None assert ( - lastSpan.attributes.get("db.statement", None) == wantAnnotatedSQL + snapshot_execute_span.attributes.get("db.statement", None) + == wantAnnotatedSQL ) # "Mismatch in annotated sql" try: diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 
0a578f200a..2df73dd0e4 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -437,7 +437,7 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 4 + assert len(span_list) == 6 assert_span_attributes( ot_exporter, @@ -447,21 +447,34 @@ def test_batch_insert_then_read(sessions_database, ot_exporter): ) assert_span_attributes( ot_exporter, - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", attributes=_make_attributes(db_name, num_mutations=2), span=span_list[1], ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Database.batch", + attributes=_make_attributes(db_name), + span=span_list[2], + ) + assert_span_attributes( ot_exporter, "CloudSpanner.GetSession", attributes=_make_attributes(db_name, session_found=True), - span=span_list[2], + span=span_list[3], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner.Snapshot.read", attributes=_make_attributes(db_name, columns=sd.COLUMNS, table_id=sd.TABLE), - span=span_list[3], + span=span_list[4], + ) + assert_span_attributes( + ot_exporter, + "CloudSpanner.Database.snapshot", + attributes=_make_attributes(db_name, multi_use=False), + span=span_list[5], ) @@ -608,7 +621,8 @@ def test_transaction_read_and_insert_then_rollback( if ot_exporter is not None: span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 8 + print("got_span_names", [span.name for span in span_list]) + # assert len(span_list) == 8 assert_span_attributes( ot_exporter, @@ -624,51 +638,58 @@ def test_transaction_read_and_insert_then_rollback( ) assert_span_attributes( ot_exporter, - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", attributes=_make_attributes(db_name, num_mutations=1), span=span_list[2], ) assert_span_attributes( ot_exporter, - "CloudSpanner.BeginTransaction", + "CloudSpanner.Database.batch", 
attributes=_make_attributes(db_name), span=span_list[3], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner.BeginTransaction", + attributes=_make_attributes(db_name), + span=span_list[4], + ) + + assert_span_attributes( + ot_exporter, + "CloudSpanner.Transaction.read", attributes=_make_attributes( db_name, table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[4], + span=span_list[5], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner.Transaction.read", attributes=_make_attributes( db_name, table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[5], + span=span_list[6], ) assert_span_attributes( ot_exporter, - "CloudSpanner.Rollback", + "CloudSpanner.Transaction.rollback", attributes=_make_attributes(db_name), - span=span_list[6], + span=span_list[7], ) assert_span_attributes( ot_exporter, - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner.Snapshot.read", attributes=_make_attributes( db_name, table_id=sd.TABLE, columns=sd.COLUMNS, ), - span=span_list[7], + span=span_list[8], ) @@ -1183,19 +1204,36 @@ def unit_of_work(transaction): session.run_in_transaction(unit_of_work) span_list = ot_exporter.get_finished_spans() - assert len(span_list) == 6 expected_span_names = [ "CloudSpanner.CreateSession", - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", + "CloudSpanner.Batch", "CloudSpanner.DMLTransaction", - "CloudSpanner.Commit", + "CloudSpanner.Transaction.commit", "CloudSpanner.ReadWriteTransaction", "Test Span", ] - assert [span.name for span in span_list] == expected_span_names - for span in span_list[2:-2]: - assert span.context.trace_id == span_list[-2].context.trace_id - assert span.parent.span_id == span_list[-2].context.span_id + got_spans = [span.name for span in span_list] + print("got_spans", got_spans) + assert got_spans == expected_span_names + + # [CreateSession --> Batch] should have their own trace. 
+ session_parent_span = span_list[0] + session_creation_child_spans = span_list[1:3] + for span in session_creation_child_spans: + assert span.context.trace_id == session_parent_span.context.trace_id + assert span.parent.span_id == session_parent_span.context.span_id + + # [Test Span --> DMLTransaction] should have their own trace. + overall_userland_span = span_list[-1] + current_context_spans = span_list[3:-1] + for span in current_context_spans: + assert span.context.trace_id == overall_userland_span.context.trace_id + assert span.parent.span_id == overall_userland_span.context.span_id + + assert ( + session_parent_span.context.trace_id != overall_userland_span.context.trace_id + ) def test_execute_partitioned_dml( diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 2f6b5e4ae9..8b50261437 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -212,7 +212,7 @@ def test_commit_grpc_error(self): batch.commit() self.assertSpanAttributes( - "CloudSpanner.Commit", + "CloudSpanner.Batch.commit", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) @@ -261,7 +261,8 @@ def test_commit_ok(self): self.assertEqual(max_commit_delay, None) self.assertSpanAttributes( - "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) + "CloudSpanner.Batch.commit", + attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) def _test_commit_with_options( @@ -327,7 +328,8 @@ def _test_commit_with_options( self.assertEqual(actual_request_options, expected_request_options) self.assertSpanAttributes( - "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) + "CloudSpanner.Batch.commit", + attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) self.assertEqual(max_commit_delay_in, max_commit_delay) @@ -438,7 +440,8 @@ def test_context_mgr_success(self): self.assertEqual(request_options, RequestOptions()) self.assertSpanAttributes( - "CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1) 
+ "CloudSpanner.Batch.commit", + attributes=dict(BASE_ATTRIBUTES, num_mutations=1), ) def test_context_mgr_failure(self): @@ -492,7 +495,7 @@ def test_batch_write_already_committed(self): group.delete(TABLE_NAME, keyset=keyset) groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -518,7 +521,7 @@ def test_batch_write_grpc_error(self): groups.batch_write() self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) @@ -580,7 +583,7 @@ def _test_batch_write_with_request_options( ) self.assertSpanAttributes( - "CloudSpanner.BatchWrite", + "CloudSpanner.batch_write", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, num_mutation_groups=1), ) diff --git a/tests/unit/test_snapshot.py b/tests/unit/test_snapshot.py index 3ea4f42577..18fca1f643 100644 --- a/tests/unit/test_snapshot.py +++ b/tests/unit/test_snapshot.py @@ -616,7 +616,7 @@ def test_read_other_error(self): list(derived.read(TABLE_NAME, COLUMNS, keyset)) self.assertSpanAttributes( - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner._Derived.read", status=StatusCode.ERROR, attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) @@ -773,7 +773,7 @@ def _read_helper( ) self.assertSpanAttributes( - "CloudSpanner.ReadOnlyTransaction.read", + "CloudSpanner._Derived.read", attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) ), @@ -868,7 +868,7 @@ def test_execute_sql_other_error(self): self.assertEqual(derived._execute_sql_count, 1) self.assertSpanAttributes( - "CloudSpanner.execute_streaming_sql", + "CloudSpanner._Derived.execute_streaming_sql", status=StatusCode.ERROR, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}), ) @@ -1024,7 +1024,7 @@ def _execute_sql_helper( 
self.assertEqual(derived._execute_sql_count, sql_count + 1) self.assertSpanAttributes( - "CloudSpanner.execute_streaming_sql", + "CloudSpanner._Derived.execute_streaming_sql", status=StatusCode.OK, attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}), ) @@ -1194,12 +1194,17 @@ def _partition_read_helper( timeout=timeout, ) + want_span_attributes = dict( + BASE_ATTRIBUTES, + table_id=TABLE_NAME, + columns=tuple(COLUMNS), + ) + if index: + want_span_attributes["index"] = index self.assertSpanAttributes( - "CloudSpanner.PartitionReadOnlyTransaction", + "CloudSpanner._Derived.partition_read", status=StatusCode.OK, - attributes=dict( - BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) - ), + attributes=want_span_attributes, ) def test_partition_read_single_use_raises(self): @@ -1226,7 +1231,7 @@ def test_partition_read_other_error(self): list(derived.partition_read(TABLE_NAME, COLUMNS, keyset)) self.assertSpanAttributes( - "CloudSpanner.PartitionReadOnlyTransaction", + "CloudSpanner._Derived.partition_read", status=StatusCode.ERROR, attributes=dict( BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS) diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py index d52fb61db1..a16ceda109 100644 --- a/tests/unit/test_transaction.py +++ b/tests/unit/test_transaction.py @@ -345,7 +345,7 @@ def test_commit_w_other_error(self): self.assertIsNone(transaction.committed) self.assertSpanAttributes( - "CloudSpanner.Commit", + "CloudSpanner.Transaction.commit", status=StatusCode.ERROR, attributes=dict(TestTransaction.BASE_ATTRIBUTES, num_mutations=1), ) @@ -427,7 +427,7 @@ def _commit_helper( self.assertEqual(transaction.commit_stats.mutation_count, 4) self.assertSpanAttributes( - "CloudSpanner.Commit", + "CloudSpanner.Transaction.commit", attributes=dict( TestTransaction.BASE_ATTRIBUTES, num_mutations=len(transaction._mutations),