Skip to content

Commit

Permalink
test: add test to verify that transactions are retried
Browse files Browse the repository at this point in the history
  • Loading branch information
olavloite committed Dec 12, 2024
1 parent a6811af commit 1c0e4cf
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 0 deletions.
13 changes: 13 additions & 0 deletions google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import inspect
import grpc
from concurrent import futures

from google.protobuf import empty_pb2
from grpc_status.rpc_status import _Status
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc
Expand All @@ -28,6 +30,7 @@
class MockSpanner:
def __init__(self):
self.results = {}
self.errors = {}

def add_result(self, sql: str, result: result_set.ResultSet):
self.results[sql.lower().strip()] = result
Expand All @@ -38,6 +41,15 @@ def get_result(self, sql: str) -> result_set.ResultSet:
raise ValueError(f"No result found for {sql}")
return result

def add_error(self, method: str, error: _Status):
self.errors[method] = error

def pop_error(self, context):
name = inspect.currentframe().f_back.f_code.co_name
error: _Status | None = self.errors.pop(name, None)
if error:
context.abort_with_status(error)

def get_result_as_partial_result_sets(
self, sql: str
) -> [result_set.PartialResultSet]:
Expand Down Expand Up @@ -174,6 +186,7 @@ def __create_transaction(

def Commit(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
Expand Down
31 changes: 31 additions & 0 deletions tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,37 @@
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
import grpc
from google.rpc import code_pb2
from google.rpc import status_pb2
from google.rpc.error_details_pb2 import RetryInfo
from google.protobuf.duration_pb2 import Duration
from grpc_status._common import code_to_grpc_status_code
from grpc_status.rpc_status import _Status


# Creates an aborted status with the smallest possible retry delay.
def aborted_status() -> _Status:
error = status_pb2.Status(
code=code_pb2.ABORTED,
message="Transaction was aborted.",
)
retry_info = RetryInfo(retry_delay=Duration(seconds=0, nanos=1))
status = _Status(
code=code_to_grpc_status_code(error.code),
details=error.message,
trailing_metadata=(
("grpc-status-details-bin", error.SerializeToString()),
(
"google.rpc.retryinfo-bin",
retry_info.SerializeToString(),
),
),
)
return status


def add_error(method: str, error: status_pb2.Status):
MockServerTestBase.spanner_service.mock_spanner.add_error(method, error)


def add_result(sql: str, result: result_set.ResultSet):
Expand Down
50 changes: 50 additions & 0 deletions tests/mockserver_tests/test_aborted_transaction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud.spanner_v1 import (
BatchCreateSessionsRequest,
BeginTransactionRequest,
CommitRequest,
)
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
from google.cloud.spanner_v1.transaction import Transaction
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_error,
aborted_status,
)


class TestAbortedTransaction(MockServerTestBase):
def test_run_in_transaction_commit_aborted(self):
# Add an Aborted error for the Commit method on the mock server.
add_error(SpannerServicer.Commit.__name__, aborted_status())
# Run a transaction. The Commit method will return Aborted the first
# time that the transaction tries to commit. It will then be retried
# and succeed.
self.database.run_in_transaction(_insert_mutations)

# Verify that the transaction was retried.
requests = self.spanner_service.requests
self.assertEqual(5, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], BeginTransactionRequest))
self.assertTrue(isinstance(requests[2], CommitRequest))
# The transaction is aborted and retried.
self.assertTrue(isinstance(requests[3], BeginTransactionRequest))
self.assertTrue(isinstance(requests[4], CommitRequest))


def _insert_mutations(transaction: Transaction):
transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])

0 comments on commit 1c0e4cf

Please sign in to comment.