Skip to content

Commit

Permalink
fix: update retry strategy for mutation calls to handle aborted transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
aakashanandg committed Dec 16, 2024
1 parent a6811af commit 51ec536
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 10 deletions.
5 changes: 1 addition & 4 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,7 @@ def commit(
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
response = self._session.run_in_transaction(method)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed
Expand Down
98 changes: 98 additions & 0 deletions tests/unit/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
# limitations under the License.


import time
import unittest
from tests._helpers import (
OpenTelemetryBase,
StatusCode,
enrich_with_otel_scope,
)
from google.cloud.spanner_v1 import RequestOptions
from unittest.mock import patch

TABLE_NAME = "citizens"
COLUMNS = ["email", "first_name", "last_name", "age"]
Expand Down Expand Up @@ -263,6 +265,71 @@ def test_commit_ok(self):
self.assertSpanAttributes(
"CloudSpanner.Commit", attributes=dict(BASE_ATTRIBUTES, num_mutations=1)
)


def test_aborted_exception_on_commit(self):
    """``batch.commit()`` raises ``Aborted`` when the commit API always aborts.

    The fake Spanner API is configured with ``_aborted_error=True`` so every
    ``commit`` call raises ``Aborted("Transaction was aborted")``; the test
    checks that the exception reaches the caller with its message intact.
    """
    import datetime
    from google.cloud.spanner_v1 import CommitResponse
    from google.cloud._helpers import UTC
    from google.cloud._helpers import _datetime_to_pb_timestamp
    from google.api_core.exceptions import Aborted

    # Build a timezone-aware timestamp directly instead of the deprecated
    # ``utcnow().replace(tzinfo=...)`` pattern.
    now = datetime.datetime.now(tz=UTC)
    now_pb = _datetime_to_pb_timestamp(now)
    response = CommitResponse(commit_timestamp=now_pb)
    database = _Database()
    # Fake Spanner API whose commit() raises Aborted on every call.
    database.spanner_api = _FauxSpannerAPI(
        _commit_response=response, _aborted_error=True
    )

    session = _Session(database)
    batch = self._make_one(session)
    batch.insert(TABLE_NAME, COLUMNS, VALUES)

    # The fake session sleeps between retry attempts; stub the sleep out so
    # the test does not spend real wall-clock time waiting.
    with patch("time.sleep"):
        with self.assertRaises(Aborted) as context:
            batch.commit()

    # The Aborted error must surface unchanged once retries are exhausted.
    self.assertEqual(str(context.exception), "409 Transaction was aborted")

def test_aborted_exception_on_commit_with_retries(self):
    """``batch.commit()`` retries the commit API before raising ``Aborted``.

    The fake Spanner API aborts every ``commit`` call. The test wraps that
    method in a mock to count invocations and checks that it was called
    more than once before ``Aborted`` reached the caller.
    """
    import datetime
    from google.cloud.spanner_v1 import CommitResponse
    from google.cloud._helpers import UTC
    from google.cloud._helpers import _datetime_to_pb_timestamp
    from google.api_core.exceptions import Aborted

    # Build a timezone-aware timestamp directly instead of the deprecated
    # ``utcnow().replace(tzinfo=...)`` pattern.
    now = datetime.datetime.now(tz=UTC)
    now_pb = _datetime_to_pb_timestamp(now)
    response = CommitResponse(commit_timestamp=now_pb)
    database = _Database()
    # Fake Spanner API whose commit() raises Aborted on every call.
    api = database.spanner_api = _FauxSpannerAPI(
        _commit_response=response, _aborted_error=True
    )

    # Wrap api.commit so calls still reach the fake but are counted, and stub
    # out time.sleep so the fake session's retry delay costs no real time.
    with patch.object(api, "commit", wraps=api.commit) as mock_commit, patch(
        "time.sleep"
    ):
        session = _Session(database)
        batch = self._make_one(session)
        batch.insert(TABLE_NAME, COLUMNS, VALUES)

        # Every attempt aborts, so the error must surface after the retries;
        # assertRaises states that expectation instead of swallowing it.
        with self.assertRaises(Aborted):
            batch.commit()

    # Verify that commit was called more than once (due to retry).
    self.assertGreater(
        mock_commit.call_count,
        1,
        "api.commit() was not called more than once on retry",
    )

def _test_commit_with_options(
self,
Expand Down Expand Up @@ -614,6 +681,33 @@ def __init__(self, database=None, name=TestBatch.SESSION_NAME):
@property
def session_id(self):
return self.name

def run_in_transaction(self, fnc, *, max_retries=3, delay=1):
    """Call ``fnc`` and retry it while it raises ``Aborted``.

    :param fnc: Zero-argument callable to invoke.
    :param max_retries: Maximum number of attempts. If less than 1,
        ``fnc`` is never called and ``None`` is returned.
    :param delay: Seconds to sleep between attempts.
    :return: The value returned by the first successful call to ``fnc``.
    :raises Aborted: If ``fnc`` raises ``Aborted`` on every attempt.

    Any exception other than ``Aborted`` propagates immediately without
    a retry.
    """
    from google.api_core.exceptions import Aborted

    for attempt in range(1, max_retries + 1):
        try:
            return fnc()
        except Aborted:
            if attempt == max_retries:
                # Retries exhausted: re-raise with the original traceback.
                raise
            time.sleep(delay)  # Wait before the next attempt.


class _Database(object):
Expand All @@ -627,6 +721,7 @@ class _FauxSpannerAPI:
_committed = None
_batch_request = None
_rpc_error = False
_aborted_error = False

def __init__(self, **kwargs):
self.__dict__.update(**kwargs)
Expand All @@ -637,6 +732,7 @@ def commit(
metadata=None,
):
from google.api_core.exceptions import Unknown
from google.api_core.exceptions import Aborted

max_commit_delay = None
if type(request).pb(request).HasField("max_commit_delay"):
Expand All @@ -653,6 +749,8 @@ def commit(
)
if self._rpc_error:
raise Unknown("error")
if self._aborted_error:
raise Aborted("Transaction was aborted")
return self._commit_response

def batch_write(
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1817,7 +1817,7 @@ def test_context_mgr_success(self):
api = database.spanner_api = self._make_spanner_client()
api.commit.return_value = response
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(
database, request_options={"transaction_tag": self.TRANSACTION_TAG}
Expand Down Expand Up @@ -1866,7 +1866,7 @@ def test_context_mgr_w_commit_stats_success(self):
api = database.spanner_api = self._make_spanner_client()
api.commit.return_value = response
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(database)

Expand Down Expand Up @@ -1910,7 +1910,7 @@ def test_context_mgr_w_commit_stats_error(self):
api = database.spanner_api = self._make_spanner_client()
api.commit.side_effect = Unknown("testing")
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(database)

Expand Down Expand Up @@ -1946,7 +1946,7 @@ def test_context_mgr_failure(self):

database = _Database(self.DATABASE_NAME)
pool = database._pool = _Pool()
session = _Session(database)
session = _Session(database, run_transaction_function = True)
pool.put(session)
checkout = self._make_one(database)

Expand Down Expand Up @@ -3094,7 +3094,6 @@ class Testing(Exception):
self.assertEqual(pool._session, session)
pool._new_session.assert_not_called()


def _make_instance_api():
from google.cloud.spanner_admin_instance_v1 import InstanceAdminClient

Expand Down Expand Up @@ -3181,10 +3180,11 @@ def __init__(
self._database = database
self.name = name
self._run_transaction_function = run_transaction_function
self._committed = False

def run_in_transaction(self, func, *args, **kw):
    """Run ``func`` directly or record the call, depending on configuration.

    When ``self._run_transaction_function`` is truthy, ``func`` is invoked
    exactly once with ``*args`` and ``**kw`` and its result is returned.
    Otherwise the call is stored on ``self._retried`` as
    ``(func, args, kw)`` and ``self._committed`` is returned.
    """
    if self._run_transaction_function:
        # Call once and hand the result back; a separate bare call ahead of
        # the return would execute the callable a second time.
        return func(*args, **kw)
    self._retried = (func, args, kw)
    return self._committed

Expand Down

0 comments on commit 51ec536

Please sign in to comment.