Skip to content

Commit 6226a07

Browse files
committed
Add full support for protocol.LAST_COMMIT_INFO (#45)
Previously the LAST_COMMIT_INFO feature was only supported at the protocol level: the driver always reported there was no information available about the last commit. Last commit ensures "read your own writes" capability for committed transactions, even when connections are using different TEs. Implement an internal structure to store last commit information: - The information is saved as a process-wide set of data - Allows the same process to connect to different databases, and to connect to the same database multiple times (to the same or to different TEs). In order to support Python clients that may use threading with different connections, use a Lock() to access the last commit details. Python threading doesn't provide reader/writer locks; rather than write one we'll use a simple mutex and assume this won't be a performance bottleneck. Add a reset() method to forget all information about previously connected databases. This is not required since each database has its own UUID, but can free up resources for long-running applications that connect to lots of different databases.
1 parent fd65998 commit 6226a07

File tree

2 files changed

+63
-9
lines changed

2 files changed

+63
-9
lines changed

pynuodb/connection.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
connect -- Creates a connection object.
1313
"""
1414

15-
__all__ = ['apilevel', 'threadsafety', 'paramstyle', 'connect', 'Connection']
15+
__all__ = ['apilevel', 'threadsafety', 'paramstyle', 'connect',
16+
'reset', 'Connection']
1617

1718
import os
1819
import copy
@@ -59,6 +60,17 @@ def connect(database=None, # type: Optional[str]
5960
options=options, **kwargs)
6061

6162

63+
def reset():
64+
# type: () -> None
65+
"""Reset the module to its initial state.
66+
67+
Forget any global state maintained by the module.
68+
NOTE: this does not impact existing connections or cursors.
69+
It only impacts new connections.
70+
"""
71+
encodedsession.EncodedSession.reset()
72+
73+
6274
class Connection(object):
6375
"""An established SQL connection with a NuoDB database.
6476

pynuodb/encodedsession.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
import struct
1616
import decimal
1717
import sys
18+
import threading
1819

1920
try:
2021
from typing import Any, Collection, Dict, List # pylint: disable=unused-import
21-
from typing import Mapping, Optional # pylint: disable=unused-import
22+
from typing import Mapping, Optional, Tuple # pylint: disable=unused-import
2223
except ImportError:
2324
pass
2425

@@ -103,6 +104,22 @@ class EncodedSession(session.Session): # pylint: disable=too-many-public-method
103104
__lastNodeId = 0
104105
__lastCommitSeq = 0
105106

107+
__dbinfo = None # type: Dict[int, Tuple[int, int]]
108+
109+
# Manage the last commit info
110+
111+
# If we decide this lock is causing performance issues we can implement
112+
# a reader/writer lock.
113+
__dblock = threading.Lock()
114+
__databases = {} # type: Dict[str, Dict[int, Tuple[int, int]]]
115+
116+
@staticmethod
117+
def reset():
118+
# type: () -> None
119+
"""Reset the EncodedSession global data."""
120+
with EncodedSession.__dblock:
121+
EncodedSession.__databases = {}
122+
106123
@property
107124
def db_uuid(self):
108125
# type: () -> Optional[uuid.UUID]
@@ -143,7 +160,7 @@ def __init__(self, host, service='SQL2', options=None, **kwargs):
143160
super(EncodedSession, self).__init__(host, service=service,
144161
options=options, **kwargs)
145162

146-
def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches
163+
def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches,too-many-statements
147164
# type: (str, str, Dict[str, str]) -> None
148165
"""Perform a handshake as a SQL client with a NuoDB TE.
149166
@@ -210,6 +227,12 @@ def open_database(self, db_name, password, parameters): # pylint: disable=too-m
210227
self.__connectedNodeID = self.getInt()
211228
self.__maxNodes = self.getInt()
212229

230+
dbid = str(self.db_uuid)
231+
with EncodedSession.__dblock:
232+
if dbid not in EncodedSession.__databases:
233+
EncodedSession.__databases[dbid] = {}
234+
self.__dbinfo = EncodedSession.__databases[dbid]
235+
213236
self.__sessionVersion = protocolVersion
214237

215238
if not self.tls_encrypted:
@@ -257,6 +280,13 @@ def send_close(self):
257280
self._exchangeMessages()
258281
self.close()
259282

283+
def __set_dbinfo(self, sid, txid, seqid):
284+
# type: (int, int, int) -> None
285+
with EncodedSession.__dblock:
286+
lci = self.__dbinfo.get(sid, (-1, -1))
287+
if seqid > lci[1]:
288+
self.__dbinfo[sid] = (txid, seqid)
289+
260290
def send_commit(self):
261291
# type: () -> int
262292
"""Commit an open transaction on this connection.
@@ -268,6 +298,7 @@ def send_commit(self):
268298
self.__lastTxnId = self.getInt()
269299
self.__lastNodeId = self.getInt()
270300
self.__lastCommitSeq = self.getInt()
301+
self.__set_dbinfo(self.__lastNodeId, self.__lastTxnId, self.__lastCommitSeq)
271302
return self.__lastTxnId
272303

273304
def send_rollback(self):
@@ -309,6 +340,13 @@ def create_statement(self):
309340
self._exchangeMessages()
310341
return statement.Statement(self.getInt())
311342

343+
def __execute_postfix(self):
344+
# type: () -> None
345+
txid = self.getInt()
346+
sid = self.getInt()
347+
seqid = self.getInt()
348+
self.__set_dbinfo(sid, txid, seqid)
349+
312350
def execute_statement(self, stmt, query):
313351
# type: (statement.Statement, str) -> statement.ExecutionResult
314352
"""Execute a query using the given statement.
@@ -322,6 +360,7 @@ def execute_statement(self, stmt, query):
322360

323361
result = self.getInt()
324362
rowcount = self.getInt()
363+
self.__execute_postfix()
325364

326365
return statement.ExecutionResult(stmt, result, rowcount)
327366

@@ -370,6 +409,7 @@ def execute_prepared_statement(
370409

371410
result = self.getInt()
372411
rowcount = self.getInt()
412+
self.__execute_postfix()
373413

374414
return statement.ExecutionResult(prepared_statement, result, rowcount)
375415

@@ -408,6 +448,8 @@ def execute_batch_prepared_statement(self, prepared_statement, param_lists):
408448
if error_string is not None:
409449
raise BatchError(error_string, results)
410450

451+
self.__execute_postfix()
452+
411453
return results
412454

413455
def fetch_result_set(self, stmt):
@@ -1127,7 +1169,12 @@ def _setup_statement(self, handle, msgId):
11271169
"""
11281170
self._putMessageId(msgId)
11291171
if self.__sessionVersion >= protocol.LAST_COMMIT_INFO:
1130-
self.putInt(self.getCommitInfo(self.__connectedNodeID))
1172+
with EncodedSession.__dblock:
1173+
self.putInt(len(self.__dbinfo))
1174+
for sid, tup in self.__dbinfo.items():
1175+
self.putInt(sid)
1176+
self.putInt(tup[0])
1177+
self.putInt(tup[1])
11311178
self.putInt(handle)
11321179

11331180
return self
@@ -1164,8 +1211,3 @@ def _takeBytes(self, length):
11641211
return self.__input[self.__inpos:self.__inpos + length]
11651212
finally:
11661213
self.__inpos += length
1167-
1168-
def getCommitInfo(self, nodeID): # pylint: disable=no-self-use
1169-
# type: (int) -> int
1170-
"""Return the last commit info. Does not support last commit."""
1171-
return 0 * nodeID

0 commit comments

Comments
 (0)