Skip to content

Commit d3ce683

Browse files
committed
test: support inline-begin in mock server
1 parent 259a78b commit d3ce683

File tree

2 files changed

+105
-10
lines changed

2 files changed

+105
-10
lines changed

google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818

1919
from google.protobuf import empty_pb2
2020
from grpc_status.rpc_status import _Status
21+
22+
from google.cloud.spanner_v1 import (
23+
TransactionOptions,
24+
ResultSetMetadata,
25+
ExecuteSqlRequest,
26+
ExecuteBatchDmlRequest,
27+
)
2128
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
2229
import google.cloud.spanner_v1.testing.spanner_database_admin_pb2_grpc as database_admin_grpc
2330
import google.cloud.spanner_v1.testing.spanner_pb2_grpc as spanner_grpc
@@ -51,23 +58,25 @@ def pop_error(self, context):
5158
context.abort_with_status(error)
5259

5360
def get_result_as_partial_result_sets(
54-
self, sql: str
61+
self, sql: str, started_transaction: transaction.Transaction
5562
) -> [result_set.PartialResultSet]:
5663
result: result_set.ResultSet = self.get_result(sql)
5764
partials = []
5865
first = True
5966
if len(result.rows) == 0:
6067
partial = result_set.PartialResultSet()
61-
partial.metadata = result.metadata
68+
partial.metadata = ResultSetMetadata(result.metadata)
6269
partials.append(partial)
6370
else:
6471
for row in result.rows:
6572
partial = result_set.PartialResultSet()
6673
if first:
67-
partial.metadata = result.metadata
74+
partial.metadata = ResultSetMetadata(result.metadata)
6875
partial.values.extend(row)
6976
partials.append(partial)
7077
partials[len(partials) - 1].stats = result.stats
78+
if started_transaction:
79+
partials[0].metadata.transaction = started_transaction
7180
return partials
7281

7382

@@ -129,22 +138,29 @@ def DeleteSession(self, request, context):
129138

130139
def ExecuteSql(self, request, context):
131140
self._requests.append(request)
132-
return result_set.ResultSet()
141+
self.mock_spanner.pop_error(context)
142+
started_transaction = self.__maybe_create_transaction(request)
143+
result: result_set.ResultSet = self.mock_spanner.get_result(request.sql)
144+
if started_transaction:
145+
result.metadata = ResultSetMetadata(result.metadata)
146+
result.metadata.transaction = started_transaction
147+
return result
133148

134149
def ExecuteStreamingSql(self, request, context):
135150
self._requests.append(request)
136-
partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql)
151+
self.mock_spanner.pop_error(context)
152+
started_transaction = self.__maybe_create_transaction(request)
153+
partials = self.mock_spanner.get_result_as_partial_result_sets(
154+
request.sql, started_transaction
155+
)
137156
for result in partials:
138157
yield result
139158

140159
def ExecuteBatchDml(self, request, context):
141160
self._requests.append(request)
161+
self.mock_spanner.pop_error(context)
142162
response = spanner.ExecuteBatchDmlResponse()
143-
started_transaction = None
144-
if not request.transaction.begin == transaction.TransactionOptions():
145-
started_transaction = self.__create_transaction(
146-
request.session, request.transaction.begin
147-
)
163+
started_transaction = self.__maybe_create_transaction(request)
148164
first = True
149165
for statement in request.statements:
150166
result = self.mock_spanner.get_result(statement.sql)
@@ -170,6 +186,16 @@ def BeginTransaction(self, request, context):
170186
self._requests.append(request)
171187
return self.__create_transaction(request.session, request.options)
172188

189+
def __maybe_create_transaction(
190+
self, request: ExecuteSqlRequest | ExecuteBatchDmlRequest
191+
):
192+
started_transaction = None
193+
if not request.transaction.begin == TransactionOptions():
194+
started_transaction = self.__create_transaction(
195+
request.session, request.transaction.begin
196+
)
197+
return started_transaction
198+
173199
def __create_transaction(
174200
self, session: str, options: transaction.TransactionOptions
175201
) -> transaction.Transaction:

tests/mockserver_tests/test_aborted_transaction.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616
BatchCreateSessionsRequest,
1717
BeginTransactionRequest,
1818
CommitRequest,
19+
ExecuteSqlRequest,
20+
TypeCode,
21+
ExecuteBatchDmlRequest,
1922
)
2023
from google.cloud.spanner_v1.testing.mock_spanner import SpannerServicer
2124
from google.cloud.spanner_v1.transaction import Transaction
2225
from tests.mockserver_tests.mock_server_test_base import (
2326
MockServerTestBase,
2427
add_error,
2528
aborted_status,
29+
add_update_count,
30+
add_single_result,
2631
)
2732

2833

@@ -45,6 +50,70 @@ def test_run_in_transaction_commit_aborted(self):
4550
self.assertTrue(isinstance(requests[3], BeginTransactionRequest))
4651
self.assertTrue(isinstance(requests[4], CommitRequest))
4752

53+
def test_run_in_transaction_update_aborted(self):
54+
add_update_count("update my_table set my_col=1 where id=2", 1)
55+
add_error(SpannerServicer.ExecuteSql.__name__, aborted_status())
56+
self.database.run_in_transaction(_execute_update)
57+
58+
# Verify that the transaction was retried.
59+
requests = self.spanner_service.requests
60+
self.assertEqual(4, len(requests), msg=requests)
61+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
62+
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
63+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
64+
self.assertTrue(isinstance(requests[3], CommitRequest))
65+
66+
def test_run_in_transaction_query_aborted(self):
67+
add_single_result(
68+
"select value from my_table where id=1",
69+
"value",
70+
TypeCode.STRING,
71+
"my-value",
72+
)
73+
add_error(SpannerServicer.ExecuteStreamingSql.__name__, aborted_status())
74+
self.database.run_in_transaction(_execute_query)
75+
76+
# Verify that the transaction was retried.
77+
requests = self.spanner_service.requests
78+
self.assertEqual(4, len(requests), msg=requests)
79+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
80+
self.assertTrue(isinstance(requests[1], ExecuteSqlRequest))
81+
self.assertTrue(isinstance(requests[2], ExecuteSqlRequest))
82+
self.assertTrue(isinstance(requests[3], CommitRequest))
83+
84+
def test_run_in_transaction_batch_dml_aborted(self):
85+
add_update_count("update my_table set my_col=1 where id=1", 1)
86+
add_update_count("update my_table set my_col=1 where id=2", 1)
87+
add_error(SpannerServicer.ExecuteBatchDml.__name__, aborted_status())
88+
self.database.run_in_transaction(_execute_batch_dml)
89+
90+
# Verify that the transaction was retried.
91+
requests = self.spanner_service.requests
92+
self.assertEqual(4, len(requests), msg=requests)
93+
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
94+
self.assertTrue(isinstance(requests[1], ExecuteBatchDmlRequest))
95+
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
96+
self.assertTrue(isinstance(requests[3], CommitRequest))
97+
4898

4999
def _insert_mutations(transaction: Transaction):
50100
transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])
101+
102+
103+
def _execute_update(transaction: Transaction):
104+
transaction.execute_update("update my_table set my_col=1 where id=2")
105+
106+
107+
def _execute_query(transaction: Transaction):
108+
rows = transaction.execute_sql("select value from my_table where id=1")
109+
for _ in rows:
110+
pass
111+
112+
113+
def _execute_batch_dml(transaction: Transaction):
114+
transaction.batch_update(
115+
[
116+
"update my_table set my_col=1 where id=1",
117+
"update my_table set my_col=1 where id=2",
118+
]
119+
)

0 commit comments

Comments
 (0)