15
15
import struct
16
16
import decimal
17
17
import sys
18
+ import threading
18
19
19
20
try :
20
21
from typing import Any , Collection , Dict , List # pylint: disable=unused-import
21
- from typing import Mapping , Optional # pylint: disable=unused-import
22
+ from typing import Mapping , Optional , Tuple # pylint: disable=unused-import
22
23
except ImportError :
23
24
pass
24
25
@@ -103,6 +104,23 @@ class EncodedSession(session.Session): # pylint: disable=too-many-public-method
103
104
__lastNodeId = 0
104
105
__lastCommitSeq = 0
105
106
107
+ # Manage the last commit info
108
+
109
+ class DbInfo (object ):
110
+ """Last commit information for a given database."""
111
+ size = None # type: int
112
+ info = None # type: Dict[int, Tuple[int, int]]
113
+
114
+ def __init__ (self , size ):
115
+ # type: (int) -> None
116
+ self .size = size
117
+ self .info = {}
118
+
119
+ # If we decide this lock is causing performance issues we can implement
120
+ # a reader/writer lock.
121
+ __dblock = threading .Lock ()
122
+ __databases = None # type: Dict[str, DbInfo]
123
+
106
124
@property
107
125
def db_uuid (self ):
108
126
# type: () -> Optional[uuid.UUID]
@@ -143,7 +161,7 @@ def __init__(self, host, service='SQL2', options=None, **kwargs):
143
161
super (EncodedSession , self ).__init__ (host , service = service ,
144
162
options = options , ** kwargs )
145
163
146
- def open_database (self , db_name , password , parameters ): # pylint: disable=too-many-branches
164
+ def open_database (self , db_name , password , parameters ): # pylint: disable=too-many-branches,too-many-statements
147
165
# type: (str, str, Dict[str, str]) -> None
148
166
"""Perform a handshake as a SQL client with a NuoDB TE.
149
167
@@ -210,6 +228,13 @@ def open_database(self, db_name, password, parameters): # pylint: disable=too-m
210
228
self .__connectedNodeID = self .getInt ()
211
229
self .__maxNodes = self .getInt ()
212
230
231
+ dbid = str (self .db_uuid )
232
+ with EncodedSession .__dblock :
233
+ if EncodedSession .__databases is None :
234
+ EncodedSession .__databases = {}
235
+ if dbid not in EncodedSession .__databases :
236
+ EncodedSession .__databases [dbid ] = EncodedSession .DbInfo (self .__maxNodes )
237
+
213
238
self .__sessionVersion = protocolVersion
214
239
215
240
if not self .tls_encrypted :
@@ -257,6 +282,14 @@ def send_close(self):
257
282
self ._exchangeMessages ()
258
283
self .close ()
259
284
285
+ def __set_dbinfo (self , sid , txid , seqid ):
286
+ # type: (int, int, int) -> None
287
+ with EncodedSession .__dblock :
288
+ info = EncodedSession .__databases [str (self .db_uuid )].info
289
+ lci = info .get (sid , (- 1 , - 1 ))
290
+ if seqid > lci [1 ]:
291
+ info [sid ] = (txid , seqid )
292
+
260
293
def send_commit (self ):
261
294
# type: () -> int
262
295
"""Commit an open transaction on this connection.
@@ -268,6 +301,7 @@ def send_commit(self):
268
301
self .__lastTxnId = self .getInt ()
269
302
self .__lastNodeId = self .getInt ()
270
303
self .__lastCommitSeq = self .getInt ()
304
+ self .__set_dbinfo (self .__lastNodeId , self .__lastTxnId , self .__lastCommitSeq )
271
305
return self .__lastTxnId
272
306
273
307
def send_rollback (self ):
@@ -309,6 +343,13 @@ def create_statement(self):
309
343
self ._exchangeMessages ()
310
344
return statement .Statement (self .getInt ())
311
345
346
+ def __execute_postfix (self ):
347
+ # type: () -> None
348
+ txid = self .getInt ()
349
+ sid = self .getInt ()
350
+ seqid = self .getInt ()
351
+ self .__set_dbinfo (sid , txid , seqid )
352
+
312
353
def execute_statement (self , stmt , query ):
313
354
# type: (statement.Statement, str) -> statement.ExecutionResult
314
355
"""Execute a query using the given statement.
@@ -322,6 +363,7 @@ def execute_statement(self, stmt, query):
322
363
323
364
result = self .getInt ()
324
365
rowcount = self .getInt ()
366
+ self .__execute_postfix ()
325
367
326
368
return statement .ExecutionResult (stmt , result , rowcount )
327
369
@@ -370,6 +412,7 @@ def execute_prepared_statement(
370
412
371
413
result = self .getInt ()
372
414
rowcount = self .getInt ()
415
+ self .__execute_postfix ()
373
416
374
417
return statement .ExecutionResult (prepared_statement , result , rowcount )
375
418
@@ -408,6 +451,8 @@ def execute_batch_prepared_statement(self, prepared_statement, param_lists):
408
451
if error_string is not None :
409
452
raise BatchError (error_string , results )
410
453
454
+ self .__execute_postfix ()
455
+
411
456
return results
412
457
413
458
def fetch_result_set (self , stmt ):
@@ -1127,7 +1172,13 @@ def _setup_statement(self, handle, msgId):
1127
1172
"""
1128
1173
self ._putMessageId (msgId )
1129
1174
if self .__sessionVersion >= protocol .LAST_COMMIT_INFO :
1130
- self .putInt (self .getCommitInfo (self .__connectedNodeID ))
1175
+ with EncodedSession .__dblock :
1176
+ dbinfo = EncodedSession .__databases [str (self .db_uuid )]
1177
+ self .putInt (len (dbinfo .info ))
1178
+ for sid , tup in dbinfo .info .items ():
1179
+ self .putInt (sid )
1180
+ self .putInt (tup [0 ])
1181
+ self .putInt (tup [1 ])
1131
1182
self .putInt (handle )
1132
1183
1133
1184
return self
@@ -1164,8 +1215,3 @@ def _takeBytes(self, length):
1164
1215
return self .__input [self .__inpos :self .__inpos + length ]
1165
1216
finally :
1166
1217
self .__inpos += length
1167
-
1168
- def getCommitInfo (self , nodeID ): # pylint: disable=no-self-use
1169
- # type: (int) -> int
1170
- """Return the last commit info. Does not support last commit."""
1171
- return 0 * nodeID
0 commit comments