18
18
19
19
from google .protobuf import empty_pb2
20
20
from grpc_status .rpc_status import _Status
21
+
22
+ from google .cloud .spanner_v1 import (
23
+ TransactionOptions ,
24
+ ResultSetMetadata ,
25
+ ExecuteSqlRequest ,
26
+ ExecuteBatchDmlRequest ,
27
+ )
21
28
from google .cloud .spanner_v1 .testing .mock_database_admin import DatabaseAdminServicer
22
29
import google .cloud .spanner_v1 .testing .spanner_database_admin_pb2_grpc as database_admin_grpc
23
30
import google .cloud .spanner_v1 .testing .spanner_pb2_grpc as spanner_grpc
@@ -51,23 +58,25 @@ def pop_error(self, context):
51
58
context .abort_with_status (error )
52
59
53
60
def get_result_as_partial_result_sets (
54
- self , sql : str
61
+ self , sql : str , started_transaction : transaction . Transaction
55
62
) -> [result_set .PartialResultSet ]:
56
63
result : result_set .ResultSet = self .get_result (sql )
57
64
partials = []
58
65
first = True
59
66
if len (result .rows ) == 0 :
60
67
partial = result_set .PartialResultSet ()
61
- partial .metadata = result .metadata
68
+ partial .metadata = ResultSetMetadata ( result .metadata )
62
69
partials .append (partial )
63
70
else :
64
71
for row in result .rows :
65
72
partial = result_set .PartialResultSet ()
66
73
if first :
67
- partial .metadata = result .metadata
74
+ partial .metadata = ResultSetMetadata ( result .metadata )
68
75
partial .values .extend (row )
69
76
partials .append (partial )
70
77
partials [len (partials ) - 1 ].stats = result .stats
78
+ if started_transaction :
79
+ partials [0 ].metadata .transaction = started_transaction
71
80
return partials
72
81
73
82
@@ -129,22 +138,29 @@ def DeleteSession(self, request, context):
129
138
130
139
def ExecuteSql (self , request , context ):
131
140
self ._requests .append (request )
132
- return result_set .ResultSet ()
141
+ self .mock_spanner .pop_error (context )
142
+ started_transaction = self .__maybe_create_transaction (request )
143
+ result : result_set .ResultSet = self .mock_spanner .get_result (request .sql )
144
+ if started_transaction :
145
+ result .metadata = ResultSetMetadata (result .metadata )
146
+ result .metadata .transaction = started_transaction
147
+ return result
133
148
134
149
def ExecuteStreamingSql (self , request , context ):
135
150
self ._requests .append (request )
136
- partials = self .mock_spanner .get_result_as_partial_result_sets (request .sql )
151
+ self .mock_spanner .pop_error (context )
152
+ started_transaction = self .__maybe_create_transaction (request )
153
+ partials = self .mock_spanner .get_result_as_partial_result_sets (
154
+ request .sql , started_transaction
155
+ )
137
156
for result in partials :
138
157
yield result
139
158
140
159
def ExecuteBatchDml (self , request , context ):
141
160
self ._requests .append (request )
161
+ self .mock_spanner .pop_error (context )
142
162
response = spanner .ExecuteBatchDmlResponse ()
143
- started_transaction = None
144
- if not request .transaction .begin == transaction .TransactionOptions ():
145
- started_transaction = self .__create_transaction (
146
- request .session , request .transaction .begin
147
- )
163
+ started_transaction = self .__maybe_create_transaction (request )
148
164
first = True
149
165
for statement in request .statements :
150
166
result = self .mock_spanner .get_result (statement .sql )
@@ -170,6 +186,16 @@ def BeginTransaction(self, request, context):
170
186
self ._requests .append (request )
171
187
return self .__create_transaction (request .session , request .options )
172
188
189
+ def __maybe_create_transaction (
190
+ self , request : ExecuteSqlRequest | ExecuteBatchDmlRequest
191
+ ):
192
+ started_transaction = None
193
+ if not request .transaction .begin == TransactionOptions ():
194
+ started_transaction = self .__create_transaction (
195
+ request .session , request .transaction .begin
196
+ )
197
+ return started_transaction
198
+
173
199
def __create_transaction (
174
200
self , session : str , options : transaction .TransactionOptions
175
201
) -> transaction .Transaction :
0 commit comments