diff --git a/pynuodb/connection.py b/pynuodb/connection.py index a69d744..4d99ae6 100644 --- a/pynuodb/connection.py +++ b/pynuodb/connection.py @@ -12,7 +12,8 @@ connect -- Creates a connection object. """ -__all__ = ['apilevel', 'threadsafety', 'paramstyle', 'connect', 'Connection'] +__all__ = ['apilevel', 'threadsafety', 'paramstyle', 'connect', + 'reset', 'Connection'] import os import copy @@ -59,6 +60,17 @@ def connect(database=None, # type: Optional[str] options=options, **kwargs) +def reset(): + # type: () -> None + """Reset the module to its initial state. + + Forget any global state maintained by the module. + NOTE: this does not impact existing connections or cursors. + It only impacts new connections. + """ + encodedsession.EncodedSession.reset() + + class Connection(object): """An established SQL connection with a NuoDB database. diff --git a/pynuodb/encodedsession.py b/pynuodb/encodedsession.py index 9aa5c20..8d3b6db 100644 --- a/pynuodb/encodedsession.py +++ b/pynuodb/encodedsession.py @@ -15,10 +15,11 @@ import struct import decimal import sys +import threading try: from typing import Any, Collection, Dict, List # pylint: disable=unused-import - from typing import Mapping, Optional # pylint: disable=unused-import + from typing import Mapping, Optional, Tuple # pylint: disable=unused-import except ImportError: pass @@ -103,6 +104,22 @@ class EncodedSession(session.Session): # pylint: disable=too-many-public-method __lastNodeId = 0 __lastCommitSeq = 0 + __dbinfo = None # type: Dict[int, Tuple[int, int]] + + # Manage the last commit info + + # If we decide this lock is causing performance issues we can implement + # a reader/writer lock. + __dblock = threading.Lock() + __databases = {} # type: Dict[str, Dict[int, Tuple[int, int]]] + + @staticmethod + def reset(): + # type: () -> None + """Reset the EncodedSession global data.""" + with EncodedSession.__dblock: + EncodedSession.__databases = {} + @property def db_uuid(self): # type: () -> Optional[uuid.UUID] @@ -143,7 +160,7 @@ def __init__(self, host, service='SQL2', options=None, **kwargs): super(EncodedSession, self).__init__(host, service=service, options=options, **kwargs) - def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches + def open_database(self, db_name, password, parameters): # pylint: disable=too-many-branches,too-many-statements # type: (str, str, Dict[str, str]) -> None """Perform a handshake as a SQL client with a NuoDB TE. @@ -210,6 +227,12 @@ def open_database(self, db_name, password, parameters): # pylint: disable=too-m self.__connectedNodeID = self.getInt() self.__maxNodes = self.getInt() + dbid = str(self.db_uuid) + with EncodedSession.__dblock: + if dbid not in EncodedSession.__databases: + EncodedSession.__databases[dbid] = {} + self.__dbinfo = EncodedSession.__databases[dbid] + self.__sessionVersion = protocolVersion if not self.tls_encrypted: @@ -257,6 +280,13 @@ def send_close(self): self._exchangeMessages() self.close() + def __set_dbinfo(self, sid, txid, seqid): + # type: (int, int, int) -> None + with EncodedSession.__dblock: + lci = self.__dbinfo.get(sid, (-1, -1)) + if seqid > lci[1]: + self.__dbinfo[sid] = (txid, seqid) + def send_commit(self): # type: () -> int """Commit an open transaction on this connection. @@ -268,6 +298,7 @@ def send_commit(self): self.__lastTxnId = self.getInt() self.__lastNodeId = self.getInt() self.__lastCommitSeq = self.getInt() + self.__set_dbinfo(self.__lastNodeId, self.__lastTxnId, self.__lastCommitSeq) return self.__lastTxnId def send_rollback(self): @@ -309,6 +340,13 @@ def create_statement(self): self._exchangeMessages() return statement.Statement(self.getInt()) + def __execute_postfix(self): + # type: () -> None + txid = self.getInt() + sid = self.getInt() + seqid = self.getInt() + self.__set_dbinfo(sid, txid, seqid) + def execute_statement(self, stmt, query): # type: (statement.Statement, str) -> statement.ExecutionResult """Execute a query using the given statement. @@ -322,6 +360,7 @@ def execute_statement(self, stmt, query): result = self.getInt() rowcount = self.getInt() + self.__execute_postfix() return statement.ExecutionResult(stmt, result, rowcount) @@ -370,6 +409,7 @@ def execute_prepared_statement( result = self.getInt() rowcount = self.getInt() + self.__execute_postfix() return statement.ExecutionResult(prepared_statement, result, rowcount) @@ -408,6 +448,8 @@ def execute_batch_prepared_statement(self, prepared_statement, param_lists): if error_string is not None: raise BatchError(error_string, results) + self.__execute_postfix() + return results def fetch_result_set(self, stmt): @@ -1127,7 +1169,12 @@ def _setup_statement(self, handle, msgId): """ self._putMessageId(msgId) if self.__sessionVersion >= protocol.LAST_COMMIT_INFO: - self.putInt(self.getCommitInfo(self.__connectedNodeID)) + with EncodedSession.__dblock: + self.putInt(len(self.__dbinfo)) + for sid, tup in self.__dbinfo.items(): + self.putInt(sid) + self.putInt(tup[0]) + self.putInt(tup[1]) self.putInt(handle) return self @@ -1164,8 +1211,3 @@ def _takeBytes(self, length): return self.__input[self.__inpos:self.__inpos + length] finally: self.__inpos += length - - def getCommitInfo(self, nodeID): # pylint: disable=no-self-use - # type: (int) -> int - """Return the last commit info. Does not support last commit.""" - return 0 * nodeID diff --git a/pynuodb/session.py b/pynuodb/session.py index 9e0f51e..b9b4899 100644 --- a/pynuodb/session.py +++ b/pynuodb/session.py @@ -470,14 +470,12 @@ def send(self, message): # don't have to reallocate this entire buffer, but it is unreliable. buf = lenbuf + data view = memoryview(buf) - start = 0 - left = len(buf) + end = len(buf) + cur = 0 try: - while left > 0: - sent = sock.send(view[start:left]) - start += sent - left -= sent + while cur < end: + cur += sock.send(view[cur:]) except Exception: self.close() raise