Skip to content

Commit 0dcb97c

Browse files
committed
Add full support for protocol.LAST_COMMIT_INFO (#45)
Previously the LAST_COMMIT_INFO feature was only supported at the protocol level: the driver always reported there was no information available about the last commit. Last commit ensures "read your own writes" capability for committed transactions, even when connections are using different TEs. Implement an internal structure to store last commit information: - The information is saved as a process-wide set of data - Allows the same process to connect to different databases, and to connect to the same database multiple times (to the same or to different TEs). In order to support Python clients that may use threading with different connections, use a Lock() to access the last commit details. Python threading doesn't provide reader/writer locks; rather than write one we'll use a simple mutex and assume this won't be a performance bottleneck.
1 parent 6b66753 commit 0dcb97c

File tree

1 file changed

+54
-8
lines changed

1 file changed

+54
-8
lines changed

pynuodb/encodedsession.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
import struct
1616
import decimal
1717
import sys
18+
import threading
1819

1920
try:
2021
from typing import Any, Collection, Dict, List # pylint: disable=unused-import
21-
from typing import Mapping, Optional # pylint: disable=unused-import
22+
from typing import Mapping, Optional, Tuple # pylint: disable=unused-import
2223
except ImportError:
2324
pass
2425

@@ -103,6 +104,23 @@ class EncodedSession(session.Session): # pylint: disable=too-many-public-method
103104
__lastNodeId = 0
104105
__lastCommitSeq = 0
105106

107+
# Manage the last commit info
108+
109+
class DbInfo(object):
110+
"""Last commit information for a given database."""
111+
size = None # type: int
112+
info = None # type: Dict[int, Tuple[int, int]]
113+
114+
def __init__(self, size):
115+
# type: (int) -> None
116+
self.size = size
117+
self.info = {}
118+
119+
# If we decide this lock is causing performance issues we can implement
120+
# a reader/writer lock.
121+
__dblock = threading.Lock()
122+
__databases = None # type: Dict[str, DbInfo]
123+
106124
@property
107125
def db_uuid(self):
108126
# type: () -> Optional[uuid.UUID]
@@ -143,7 +161,7 @@ def __init__(self, host, service='SQL2', options=None, **kwargs):
143161
super(EncodedSession, self).__init__(host, service=service,
144162
options=options, **kwargs)
145163

146-
def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches
164+
def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches,too-many-statements
147165
# type: (str, str, Dict[str, str]) -> None
148166
"""Perform a handshake as a SQL client with a NuoDB TE.
149167
@@ -210,6 +228,13 @@ def open_database(self, db_name, password, parameters): # pylint: disable=too-m
210228
self.__connectedNodeID = self.getInt()
211229
self.__maxNodes = self.getInt()
212230

231+
dbid = str(self.db_uuid)
232+
with EncodedSession.__dblock:
233+
if EncodedSession.__databases is None:
234+
EncodedSession.__databases = {}
235+
if dbid not in EncodedSession.__databases:
236+
EncodedSession.__databases[dbid] = EncodedSession.DbInfo(self.__maxNodes)
237+
213238
self.__sessionVersion = protocolVersion
214239

215240
if not self.tls_encrypted:
@@ -257,6 +282,14 @@ def send_close(self):
257282
self._exchangeMessages()
258283
self.close()
259284

285+
def __set_dbinfo(self, sid, txid, seqid):
286+
# type: (int, int, int) -> None
287+
with EncodedSession.__dblock:
288+
info = EncodedSession.__databases[str(self.db_uuid)].info
289+
lci = info.get(sid, (-1, -1))
290+
if seqid > lci[1]:
291+
info[sid] = (txid, seqid)
292+
260293
def send_commit(self):
261294
# type: () -> int
262295
"""Commit an open transaction on this connection.
@@ -268,6 +301,7 @@ def send_commit(self):
268301
self.__lastTxnId = self.getInt()
269302
self.__lastNodeId = self.getInt()
270303
self.__lastCommitSeq = self.getInt()
304+
self.__set_dbinfo(self.__lastNodeId, self.__lastTxnId, self.__lastCommitSeq)
271305
return self.__lastTxnId
272306

273307
def send_rollback(self):
@@ -309,6 +343,13 @@ def create_statement(self):
309343
self._exchangeMessages()
310344
return statement.Statement(self.getInt())
311345

346+
def __execute_postfix(self):
347+
# type: () -> None
348+
txid = self.getInt()
349+
sid = self.getInt()
350+
seqid = self.getInt()
351+
self.__set_dbinfo(sid, txid, seqid)
352+
312353
def execute_statement(self, stmt, query):
313354
# type: (statement.Statement, str) -> statement.ExecutionResult
314355
"""Execute a query using the given statement.
@@ -322,6 +363,7 @@ def execute_statement(self, stmt, query):
322363

323364
result = self.getInt()
324365
rowcount = self.getInt()
366+
self.__execute_postfix()
325367

326368
return statement.ExecutionResult(stmt, result, rowcount)
327369

@@ -370,6 +412,7 @@ def execute_prepared_statement(
370412

371413
result = self.getInt()
372414
rowcount = self.getInt()
415+
self.__execute_postfix()
373416

374417
return statement.ExecutionResult(prepared_statement, result, rowcount)
375418

@@ -408,6 +451,8 @@ def execute_batch_prepared_statement(self, prepared_statement, param_lists):
408451
if error_string is not None:
409452
raise BatchError(error_string, results)
410453

454+
self.__execute_postfix()
455+
411456
return results
412457

413458
def fetch_result_set(self, stmt):
@@ -1127,7 +1172,13 @@ def _setup_statement(self, handle, msgId):
11271172
"""
11281173
self._putMessageId(msgId)
11291174
if self.__sessionVersion >= protocol.LAST_COMMIT_INFO:
1130-
self.putInt(self.getCommitInfo(self.__connectedNodeID))
1175+
with EncodedSession.__dblock:
1176+
dbinfo = EncodedSession.__databases[str(self.db_uuid)]
1177+
self.putInt(len(dbinfo.info))
1178+
for sid, tup in dbinfo.info.items():
1179+
self.putInt(sid)
1180+
self.putInt(tup[0])
1181+
self.putInt(tup[1])
11311182
self.putInt(handle)
11321183

11331184
return self
@@ -1164,8 +1215,3 @@ def _takeBytes(self, length):
11641215
return self.__input[self.__inpos:self.__inpos + length]
11651216
finally:
11661217
self.__inpos += length
1167-
1168-
def getCommitInfo(self, nodeID): # pylint: disable=no-self-use
1169-
# type: (int) -> int
1170-
"""Return the last commit info. Does not support last commit."""
1171-
return 0 * nodeID

0 commit comments

Comments
 (0)