Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update retry strategy for mutation calls to handle aborted transactions #1270

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,7 @@ def commit(
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
response = self._session.run_in_transaction(method)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm.... Now that I see this, I realize that this cannot use run_in_transaction directly. The reason is that:

  1. This method batch.commit() in its original form (so without the changes in this pull request) creates and commits a single-use read/write transaction.
  2. run_in_transaction however creates a new read/write transaction and then executes the given method in the scope of that transaction.

So what happens now is that you create two transactions:

  1. run_in_transaction creates a transaction that is not being used.
  2. method that is passed in to run_in_transaction executes a Commit in a single-use read/write transaction.

self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed
Expand Down
98 changes: 98 additions & 0 deletions tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.


import time
import unittest
from tests._helpers import (
OpenTelemetryBase,
StatusCode,
enrich_with_otel_scope,
)
from google.cloud.spanner_v1 import RequestOptions
from unittest.mock import patch

TABLE_NAME = "citizens"
COLUMNS = ["email", "first_name", "last_name", "age"]
Expand Down Expand Up @@ -263,6 +265,71 @@ def test_commit_ok(self):
self.assertSpanAttributes(
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)


def test_aborted_exception_on_commit(self):
    """``batch.commit()`` propagates ``Aborted`` raised by the commit API.

    The fake Spanner API is configured with ``_aborted_error=True`` so its
    ``commit()`` raises ``Aborted("Transaction was aborted")``; the test
    asserts that the exception reaches the caller of ``batch.commit()``.
    """
    import datetime
    from google.cloud.spanner_v1 import CommitResponse
    from google.cloud._helpers import UTC
    from google.cloud._helpers import _datetime_to_pb_timestamp
    from google.api_core.exceptions import Aborted

    # Build an aware UTC timestamp directly instead of the deprecated
    # ``utcnow()`` + ``replace(tzinfo=...)`` combination.
    now = datetime.datetime.now(tz=UTC)
    now_pb = _datetime_to_pb_timestamp(now)
    response = CommitResponse(commit_timestamp=now_pb)
    database = _Database()
    # Fake Spanner API that raises Aborted from commit() because
    # ``_aborted_error`` is set.
    database.spanner_api = _FauxSpannerAPI(
        _commit_response=response, _aborted_error=True
    )

    session = _Session(database)
    batch = self._make_one(session)
    batch.insert(TABLE_NAME, COLUMNS, VALUES)

    with self.assertRaises(Aborted) as context:
        batch.commit()

    # str() of the exception carries the 409 status prefix in front of the
    # message passed to Aborted().
    self.assertEqual(str(context.exception), "409 Transaction was aborted")

def test_aborted_exception_on_commit_with_retries(self):
# Test case to verify that an Aborted exception is raised when
# batch.commit() is invoked, the transaction is internally aborted,
# and the Spanner commit API calls are retried.
import datetime
from google.cloud.spanner_v1 import CommitResponse
from google.cloud.spanner_v1 import TransactionOptions
from google.cloud._helpers import UTC
from google.cloud._helpers import _datetime_to_pb_timestamp
from google.api_core.exceptions import Aborted

now = datetime.datetime.utcnow().replace(tzinfo=UTC)
now_pb = _datetime_to_pb_timestamp(now)
response = CommitResponse(commit_timestamp=now_pb)
database = _Database()
# Setup the spanner API which throws Aborted exception when calling commit API.
api = database.spanner_api = _FauxSpannerAPI(_commit_response=response, _aborted_error = True)

# We will patch the commit method of _FauxSpannerAPI to track how many times it's called
with patch.object(api, 'commit', wraps=api.commit) as mock_commit:
# Set up a mock session and batch
session = _Session(database)
batch = self._make_one(session)
batch.insert(TABLE_NAME, COLUMNS, VALUES)

# Try committing the batch, which should call api.commit() and raise an Aborted exception
# The retry logic should call commit() again after handling the Aborted exception
try:
batch.commit()
except Aborted:
pass

# Verify that commit was called more than once (due to retry)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should verify the exact number. Based on my comment above, my guess is that you are getting 1 more commit than you would expect, as run_in_transaction also creates a read/write transaction.

self.assertGreater(mock_commit.call_count, 1, "api.commit() was not called more than once on retry")

def _test_commit_with_options(
self,
Expand Down Expand Up @@ -614,6 +681,33 @@ def __init__(self, database=None, name=TestBatch.SESSION_NAME):
@property
def session_id(self):
return self.name

def run_in_transaction(self, fnc, max_retries=3, delay=1):
    """Run ``fnc``, retrying it while it raises ``Aborted``.

    Simplified retry helper used by the unit tests in place of a real
    session: it does not create a transaction, it only re-invokes ``fnc``.

    :param fnc: Zero-argument callable to invoke.
    :param max_retries: Maximum total number of times ``fnc`` is called
        (the first call counts). Must be at least 1.
    :param delay: Delay (in seconds) to sleep between attempts.
    :return: The value returned by the first successful call to ``fnc``.
    :raises ValueError: If ``max_retries`` is less than 1.
    :raises google.api_core.exceptions.Aborted: If ``fnc`` raised
        ``Aborted`` on every attempt; the last one is re-raised.
        Any other exception from ``fnc`` propagates immediately.
    """
    from google.api_core.exceptions import Aborted

    if max_retries < 1:
        # Without this guard the loop would not run and None would be
        # returned without ever calling ``fnc``.
        raise ValueError("max_retries must be at least 1")

    for attempt in range(1, max_retries + 1):
        try:
            return fnc()
        except Aborted:
            if attempt == max_retries:
                # Attempts exhausted: surface the original exception with
                # its traceback intact.
                raise
            time.sleep(delay)  # Wait before retrying


class _Database(object):
Expand All @@ -627,6 +721,7 @@ class _FauxSpannerAPI:
_committed = None
_batch_request = None
_rpc_error = False
_aborted_error = False

def __init__(self, **kwargs):
self.__dict__.update(**kwargs)
Expand All @@ -637,6 +732,7 @@ def commit(
metadata=None,
):
from google.api_core.exceptions import Unknown
from google.api_core.exceptions import Aborted

max_commit_delay = None
if type(request).pb(request).HasField("max_commit_delay"):
Expand All @@ -653,6 +749,8 @@ def commit(
)
if self._rpc_error:
raise Unknown("error")
if self._aborted_error:
raise Aborted("Transaction was aborted")
return self._commit_response

def batch_write(
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,7 +1817,7 @@ def test_context_mgr_success(self):
api = database.spanner_api = self._make_spanner_client()
api.commit.return_value = response
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(
database, request_options={"transaction_tag": self.TRANSACTION_TAG}
Expand Down Expand Up @@ -1866,7 +1866,7 @@ def test_context_mgr_w_commit_stats_success(self):
api = database.spanner_api = self._make_spanner_client()
api.commit.return_value = response
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(database)

Expand Down Expand Up @@ -1910,7 +1910,7 @@ def test_context_mgr_w_commit_stats_error(self):
api = database.spanner_api = self._make_spanner_client()
api.commit.side_effect = Unknown("testing")
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(database)

Expand Down Expand Up @@ -1946,7 +1946,7 @@ def test_context_mgr_failure(self):

database = _Database(self.DATABASE_NAME)
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(database)

Expand Down Expand Up @@ -3094,7 +3094,6 @@ class Testing(Exception):
self.assertEqual(pool._session, session)
pool._new_session.assert_not_called()


def _make_instance_api():
from google.cloud.spanner_admin_instance_v1 import InstanceAdminClient

Expand Down Expand Up @@ -3181,10 +3180,11 @@ def __init__(
self._database = database
self.name = name
self._run_transaction_function = run_transaction_function
self._committed = False

def run_in_transaction(self, func, *args, **kw):
    """Test-double ``run_in_transaction`` with two modes.

    :param func: Callable the caller wants run "in a transaction".
    :param args: Positional arguments forwarded to ``func``.
    :param kw: Keyword arguments forwarded to ``func``.
    :return: The result of ``func(*args, **kw)`` when
        ``self._run_transaction_function`` is truthy; otherwise
        ``self._committed``.
    """
    if self._run_transaction_function:
        # Invoke the callable exactly once and return its result so the
        # caller (e.g. a commit that reads the response) can use it.
        return func(*args, **kw)
    # Record-only mode: remember what was requested without running it.
    self._retried = (func, args, kw)
    return self._committed

Expand Down
Loading