Skip to content

Add full support for protocol.LAST_COMMIT_INFO (#45) #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion pynuodb/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
connect -- Creates a connection object.
"""

__all__ = ['apilevel', 'threadsafety', 'paramstyle', 'connect', 'Connection']
__all__ = ['apilevel', 'threadsafety', 'paramstyle', 'connect',
'reset', 'Connection']

import os
import copy
Expand Down Expand Up @@ -59,6 +60,17 @@ def connect(database=None, # type: Optional[str]
options=options, **kwargs)


def reset():
# type: () -> None
"""Reset the module to its initial state.

Forget any global state maintained by the module.
NOTE: this does not impact existing connections or cursors.
It only impacts new connections.
"""
encodedsession.EncodedSession.reset()


class Connection(object):
"""An established SQL connection with a NuoDB database.

Expand Down
58 changes: 50 additions & 8 deletions pynuodb/encodedsession.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
import struct
import decimal
import sys
import threading

try:
from typing import Any, Collection, Dict, List # pylint: disable=unused-import
from typing import Mapping, Optional # pylint: disable=unused-import
from typing import Mapping, Optional, Tuple # pylint: disable=unused-import
except ImportError:
pass

Expand Down Expand Up @@ -103,6 +104,22 @@ class EncodedSession(session.Session): # pylint: disable=too-many-public-method
__lastNodeId = 0
__lastCommitSeq = 0

__dbinfo = None # type: Dict[int, Tuple[int, int]]

# Manage the last commit info

# If we decide this lock is causing performance issues we can implement
# a reader/writer lock.
__dblock = threading.Lock()
__databases = {} # type: Dict[str, Dict[int, Tuple[int, int]]]

@staticmethod
def reset():
# type: () -> None
"""Reset the EncodedSession global data."""
with EncodedSession.__dblock:
EncodedSession.__databases = {}

@property
def db_uuid(self):
# type: () -> Optional[uuid.UUID]
Expand Down Expand Up @@ -143,7 +160,7 @@ def __init__(self, host, service='SQL2', options=None, **kwargs):
super(EncodedSession, self).__init__(host, service=service,
options=options, **kwargs)

def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches
def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches,too-many-statements
# type: (str, str, Dict[str, str]) -> None
"""Perform a handshake as a SQL client with a NuoDB TE.

Expand Down Expand Up @@ -210,6 +227,12 @@ def open_database(self, db_name, password, parameters): # pylint: disable=too-m
self.__connectedNodeID = self.getInt()
self.__maxNodes = self.getInt()

dbid = str(self.db_uuid)
with EncodedSession.__dblock:
if dbid not in EncodedSession.__databases:
EncodedSession.__databases[dbid] = {}
self.__dbinfo = EncodedSession.__databases[dbid]

self.__sessionVersion = protocolVersion

if not self.tls_encrypted:
Expand Down Expand Up @@ -257,6 +280,13 @@ def send_close(self):
self._exchangeMessages()
self.close()

def __set_dbinfo(self, sid, txid, seqid):
# type: (int, int, int) -> None
with EncodedSession.__dblock:
lci = self.__dbinfo.get(sid, (-1, -1))
if seqid > lci[1]:
self.__dbinfo[sid] = (txid, seqid)

def send_commit(self):
# type: () -> int
"""Commit an open transaction on this connection.
Expand All @@ -268,6 +298,7 @@ def send_commit(self):
self.__lastTxnId = self.getInt()
self.__lastNodeId = self.getInt()
self.__lastCommitSeq = self.getInt()
self.__set_dbinfo(self.__lastNodeId, self.__lastTxnId, self.__lastCommitSeq)
return self.__lastTxnId

def send_rollback(self):
Expand Down Expand Up @@ -309,6 +340,13 @@ def create_statement(self):
self._exchangeMessages()
return statement.Statement(self.getInt())

def __execute_postfix(self):
# type: () -> None
txid = self.getInt()
sid = self.getInt()
seqid = self.getInt()
self.__set_dbinfo(sid, txid, seqid)

def execute_statement(self, stmt, query):
# type: (statement.Statement, str) -> statement.ExecutionResult
"""Execute a query using the given statement.
Expand All @@ -322,6 +360,7 @@ def execute_statement(self, stmt, query):

result = self.getInt()
rowcount = self.getInt()
self.__execute_postfix()

return statement.ExecutionResult(stmt, result, rowcount)

Expand Down Expand Up @@ -370,6 +409,7 @@ def execute_prepared_statement(

result = self.getInt()
rowcount = self.getInt()
self.__execute_postfix()

return statement.ExecutionResult(prepared_statement, result, rowcount)

Expand Down Expand Up @@ -408,6 +448,8 @@ def execute_batch_prepared_statement(self, prepared_statement, param_lists):
if error_string is not None:
raise BatchError(error_string, results)

self.__execute_postfix()

return results

def fetch_result_set(self, stmt):
Expand Down Expand Up @@ -1127,7 +1169,12 @@ def _setup_statement(self, handle, msgId):
"""
self._putMessageId(msgId)
if self.__sessionVersion >= protocol.LAST_COMMIT_INFO:
self.putInt(self.getCommitInfo(self.__connectedNodeID))
with EncodedSession.__dblock:
self.putInt(len(self.__dbinfo))
for sid, tup in self.__dbinfo.items():
self.putInt(sid)
self.putInt(tup[0])
self.putInt(tup[1])
self.putInt(handle)

return self
Expand Down Expand Up @@ -1164,8 +1211,3 @@ def _takeBytes(self, length):
return self.__input[self.__inpos:self.__inpos + length]
finally:
self.__inpos += length

def getCommitInfo(self, nodeID): # pylint: disable=no-self-use
# type: (int) -> int
"""Return the last commit info. Does not support last commit."""
return 0 * nodeID
10 changes: 4 additions & 6 deletions pynuodb/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,14 +470,12 @@ def send(self, message):
# don't have to reallocate this entire buffer, but it is unreliable.
buf = lenbuf + data
view = memoryview(buf)
start = 0
left = len(buf)
end = len(buf)
cur = 0

try:
while left > 0:
sent = sock.send(view[start:left])
start += sent
left -= sent
while cur < end:
cur += sock.send(view[cur:])
except Exception:
self.close()
raise
Expand Down