Skip to content

Commit

Permalink
Reduce edit changes
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Nov 18, 2024
1 parent a1c45aa commit bc1befe
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion google/cloud/spanner_v1/_opentelemetry_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def trace_call(name, session, extra_attributes=None, observability_options=None)
attributes = {
"db.type": "spanner",
"db.url": SpannerClient.DEFAULT_ENDPOINT,
"db.instance": session._database.name,
"db.instance": "" if not session._database else session._database.name,
"net.host.name": SpannerClient.DEFAULT_ENDPOINT,
OTEL_SCOPE_NAME: TRACER_NAME,
OTEL_SCOPE_VERSION: TRACER_VERSION,
Expand Down
35 changes: 16 additions & 19 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,26 +890,23 @@ def run_in_transaction(self, func, *args, **kw):
# Sanity check: Is there a transaction already running?
# If there is, then raise a red flag. Otherwise, mark that this one
# is running.
with SessionCheckout(self._pool) as session:
observability_options = getattr(self, "observability_options", None)
with trace_call(
"CloudSpanner.Database.run_in_transaction",
session,
observability_options=observability_options,
):
# Sanity check: Is there a transaction already running?
# If there is, then raise a red flag. Otherwise, mark that this one
# is running.
if getattr(self._local, "transaction_running", False):
raise RuntimeError("Spanner does not support nested transactions.")
self._local.transaction_running = True

# Check out a session and run the function in a transaction; once
# done, flip the sanity check bit back.
try:
if getattr(self._local, "transaction_running", False):
raise RuntimeError("Spanner does not support nested transactions.")
self._local.transaction_running = True

# Check out a session and run the function in a transaction; once
# done, flip the sanity check bit back.
try:
with SessionCheckout(self._pool) as session:
observability_options = getattr(self, "observability_options", None)
with trace_call(
"CloudSpanner.Database.run_in_transaction",
session,
observability_options=observability_options,
):
return session.run_in_transaction(func, *args, **kw)
finally:
self._local.transaction_running = False
finally:
self._local.transaction_running = False

def restore(self, source):
"""Restore from a backup to this database.
Expand Down
9 changes: 6 additions & 3 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def run_in_transaction(self, func, *args, **kw):
observability_options = getattr(self._database, "observability_options", None)
attempts = 0

def __run_txn(txn, attempts):
def __run_txn_and_return(txn, attempts):
try:
return_value = func(txn, *args, **kw)
except Aborted as exc:
Expand Down Expand Up @@ -457,6 +457,9 @@ def __run_txn(txn, attempts):
)
return return_value, True

# Signal to the caller to continue iterating.
return None, False

while True:
if self._transaction is None:
with trace_call(
Expand All @@ -467,12 +470,12 @@ def __run_txn(txn, attempts):
txn.exclude_txn_from_change_streams = (
exclude_txn_from_change_streams
)
return_value, completed = __run_txn(txn, attempts)
return_value, completed = __run_txn_and_return(txn, attempts)
if completed:
return return_value
else:
txn = self._transaction
return_value, completed = __run_txn(txn, attempts)
return_value, completed = __run_txn_and_return(txn, attempts)
if completed:
return return_value

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def test_execute_sql_other_error(self):
self.assertEqual(derived._execute_sql_count, 1)

self.assertSpanAttributes(
"CloudSpanner.ReadWriteTransaction",
"CloudSpanner.execute_sql",
status=StatusCode.ERROR,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY}),
)
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def _execute_sql_helper(
self.assertEqual(derived._execute_sql_count, sql_count + 1)

self.assertSpanAttributes(
"CloudSpanner.ReadWriteTransaction",
"CloudSpanner.execute_sql",
status=StatusCode.OK,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}),
)
Expand Down

0 comments on commit bc1befe

Please sign in to comment.