Skip to content

Commit

Permalink
Make updates and fix up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Dec 2, 2024
1 parent 54bed9c commit 3ff4176
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 65 deletions.
32 changes: 25 additions & 7 deletions google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ def _make_tracer_and_span_attributes(

if not enable_extended_tracing:
attributes.pop("db.statement", False)
attributes.pop("sql", False)
else:
# Otherwise there are places where the annotated sql was inserted
# directly from the arguments as "sql", and transform those into "db.statement".
db_statement = attributes.get("db.statement", None)
if not db_statement:
sql = attributes.get("sql", None)
if sql:
attributes = attributes.copy()
attributes.pop("sql", False)
attributes["db.statement"] = sql

return tracer, attributes

Expand All @@ -111,7 +122,10 @@ def trace_call_end_lazily(
 context manager, please invoke `trace_call` with which you can invoke
 `with trace_call(...) as span:`
 It is the caller's responsibility to explicitly invoke span.end()
"""
"""
if not name:
return None

tracer, span_attributes = _make_tracer_and_span_attributes(
session, extra_attributes, observability_options
)
Expand All @@ -128,7 +142,11 @@ def trace_call(name, session=None, extra_attributes=None, observability_options=
 trace_call is used in situations where you need to end a span with a context manager
 or after a scope is exited. If you need to keep a span alive and lazily end it, please
 invoke `trace_call_end_lazily`.
"""
"""
if not name:
yield None
return

tracer, span_attributes = _make_tracer_and_span_attributes(
session, extra_attributes, observability_options
)
Expand Down Expand Up @@ -165,9 +183,9 @@ def get_current_span():
return trace.get_current_span()


def add_event_on_current_span(self, event_name, attributes=None, current_span=None):
if not current_span:
current_span = get_current_span()
def add_event_on_current_span(event_name, attributes=None, span=None):
if not span:
span = get_current_span()

if current_span:
current_span.add_event(event_name, attributes)
if span:
span.add_event(event_name, attributes)
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, session):
super(_BatchBase, self).__init__(session)
self._mutations = []
self.__span = trace_call_end_lazily(
f"CloudSpanner.{type(self).__name}",
f"CloudSpanner.{type(self).__name__}",
self._session,
None,
getattr(self._session._database, "observability_options", None),
Expand Down
98 changes: 52 additions & 46 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,16 +702,14 @@ def execute_partitioned_dml(

def execute_pdml():
def do_execute_pdml(session, span):
add_event_on_current_span(
"Starting BeginTransaction", current_span=span
)
add_event_on_current_span("Starting BeginTransaction", span=span)
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)
add_event_on_current_span(
"Completed BeginTransaction",
{"transaction.id": txn.id},
current_span=span,
span=span,
)
txn_selector = TransactionSelector(id=txn.id)

Expand All @@ -731,7 +729,7 @@ def do_execute_pdml(session, span):
iterator = _restart_on_unavailable(
method=method,
request=request,
trace_name="CloudSpannerOperation.ExecuteStreamingSql",
span_name="CloudSpanner.ExecuteStreamingSql",
transaction_selector=txn_selector,
observability_options=self.observability_options,
)
Expand All @@ -741,11 +739,9 @@ def do_execute_pdml(session, span):

return result_set.stats.row_count_lower_bound

observability_options = getattr(self, "observability_options", {})
with trace_call(
"CloudSpanner.execute_partitioned_pdml",
None,
observability_options=observability_options,
"CloudSpanner.Database.execute_partitioned_pdml",
observability_options=self.observability_options,
) as span:
with SessionCheckout(self._pool) as session:
return do_execute_pdml(session, span)
Expand Down Expand Up @@ -1232,8 +1228,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if not exc_type:
set_span_status_ok(self.__span)
else:
set_span_status_error(self.__span, exc_type)
self.__span.record_exception(exc_type)
set_span_status_error(self.__span, exc_val)
self.__span.record_exception(exc_val)
self.__span.end()
self.__span = None

Expand Down Expand Up @@ -1324,8 +1320,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):
if not exc_type:
set_span_status_ok(self.__span)
else:
set_span_status_error(self.__span, exc_type)
self.__span.record_exception(exc_type)
set_span_status_error(self.__span, exc_val)
self.__span.record_exception(exc_val)
self.__span.end()
self.__span = None

Expand Down Expand Up @@ -1527,7 +1523,7 @@ def generate_read_batches(
:meth:`process_read_batch`.
"""
with trace_call(
f"CloudSpanner.{type(self).__name__}.generate_read_partitions",
f"CloudSpanner.{type(self).__name__}.generate_read_batches",
extra_attributes=dict(table=table, columns=columns),
observability_options=self.observability_options,
):
Expand Down Expand Up @@ -1578,10 +1574,8 @@ def process_read_batch(
:returns: a result set instance which can be used to consume rows.
"""
observability_options = self.observability_options or {}
session = self._get_session()
klassname = type(self).__name__
with trace_call(
"CloudSpanner." + klassname + ".process_read_batch",
f"CloudSpanner.{type(self).__name__}.process_read_batch",
session,
observability_options=observability_options,
):
Expand Down Expand Up @@ -1665,34 +1659,39 @@ def generate_query_batches(
mappings of information used perform actual partitioned reads via
:meth:`process_read_batch`.
"""
partitions = self._get_snapshot().partition_query(
sql=sql,
params=params,
param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.generate_query_batches",
extra_attributes=dict(sql=sql),
observability_options=self.observability_options,
):
partitions = self._get_snapshot().partition_query(
sql=sql,
params=params,
param_types=param_types,
partition_size_bytes=partition_size_bytes,
max_partitions=max_partitions,
retry=retry,
timeout=timeout,
)

query_info = {
"sql": sql,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
if params:
query_info["params"] = params
query_info["param_types"] = param_types

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = self._database._instance._client._query_options
query_info["query_options"] = _merge_query_options(
default_query_options, query_options
)
query_info = {
"sql": sql,
"data_boost_enabled": data_boost_enabled,
"directed_read_options": directed_read_options,
}
if params:
query_info["params"] = params
query_info["param_types"] = param_types

# Query-level options have higher precedence than client-level and
# environment-level options
default_query_options = self._database._instance._client._query_options
query_info["query_options"] = _merge_query_options(
default_query_options, query_options
)

for partition in partitions:
yield {"partition": partition, "query": query_info}
for partition in partitions:
yield {"partition": partition, "query": query_info}

def process_query_batch(
self,
Expand All @@ -1717,9 +1716,16 @@ def process_query_batch(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
return self._get_snapshot().execute_sql(
partition=batch["partition"], **batch["query"], retry=retry, timeout=timeout
)
with trace_call(
f"CloudSpanner.{type(self).__name__}.process_query_batch",
observability_options=self.observability_options,
):
return self._get_snapshot().execute_sql(
partition=batch["partition"],
**batch["query"],
retry=retry,
timeout=timeout,
)

def run_partitioned_query(
self,
Expand Down
8 changes: 4 additions & 4 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
def _restart_on_unavailable(
method,
request,
trace_name=None,
span_name=None,
session=None,
attributes=None,
transaction=None,
Expand Down Expand Up @@ -87,7 +87,7 @@ def _restart_on_unavailable(
request.transaction = transaction_selector

with trace_call(
trace_name, session, attributes, observability_options=observability_options
span_name, session, attributes, observability_options=observability_options
):
iterator = method(request=request)
while True:
Expand All @@ -109,7 +109,7 @@ def _restart_on_unavailable(
except ServiceUnavailable:
del item_buffer[:]
with trace_call(
trace_name,
span_name,
session,
attributes,
observability_options=observability_options,
Expand All @@ -129,7 +129,7 @@ def _restart_on_unavailable(
raise
del item_buffer[:]
with trace_call(
trace_name,
span_name,
session,
attributes,
observability_options=observability_options,
Expand Down
14 changes: 12 additions & 2 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,16 @@ def tearDown(self):

def assertNoSpans(self):
if HAS_OPENTELEMETRY_INSTALLED:
span_list = self.ot_exporter.get_finished_spans()
span_list = self.get_finished_spans()
print("got_span_list", [span.name for span in span_list])
self.assertEqual(len(span_list), 0)

def assertSpanAttributes(
self, name, status=StatusCode.OK, attributes=None, span=None
):
if HAS_OPENTELEMETRY_INSTALLED:
if not span:
span_list = self.ot_exporter.get_finished_spans()
span_list = self.get_finished_spans()
self.assertEqual(len(span_list), 1)
span = span_list[0]

Expand All @@ -94,3 +95,12 @@ def assertSpanAttributes(
print("got_span_attributes ", dict(span.attributes))
print("want_span_attributes", attributes)
self.assertEqual(dict(span.attributes), attributes)

def get_finished_spans(self):
if not HAS_OPENTELEMETRY_INSTALLED:
return []

spans = self.ot_exporter.get_finished_spans()
# A span with name=None is the result from invoking trace_call without
# intention to trace, hence these have to be filtered out.
return list(filter(lambda span: span.name, spans))
1 change: 0 additions & 1 deletion tests/system/test_session_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,7 +1214,6 @@ def unit_of_work(transaction):
"Test Span",
]
got_spans = [span.name for span in span_list]
print("got_spans", got_spans)
assert got_spans == expected_span_names

# [CreateSession --> Batch] should have their own trace.
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1374,7 +1374,7 @@ def _partition_query_helper(
)

self.assertSpanAttributes(
"CloudSpanner.PartitionReadWriteTransaction",
"CloudSpanner._Derived.partition_query",
status=StatusCode.OK,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}),
)
Expand All @@ -1392,7 +1392,7 @@ def test_partition_query_other_error(self):
list(derived.partition_query(SQL_QUERY))

self.assertSpanAttributes(
"CloudSpanner.PartitionReadWriteTransaction",
"CloudSpanner._Derived.partition_query",
status=StatusCode.ERROR,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}),
)
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/test_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def test_rollback_w_other_error(self):
self.assertFalse(transaction.rolled_back)

self.assertSpanAttributes(
"CloudSpanner.Rollback",
"CloudSpanner.Transaction.rollback",
status=StatusCode.ERROR,
attributes=TestTransaction.BASE_ATTRIBUTES,
)
Expand Down Expand Up @@ -299,7 +299,8 @@ def test_rollback_ok(self):
)

self.assertSpanAttributes(
"CloudSpanner.Rollback", attributes=TestTransaction.BASE_ATTRIBUTES
"CloudSpanner.Transaction.rollback",
attributes=TestTransaction.BASE_ATTRIBUTES,
)

def test_commit_not_begun(self):
Expand Down

0 comments on commit 3ff4176

Please sign in to comment.