@@ -43,6 +43,7 @@ def __init__(self, url, **kwargs):
43
43
import os
44
44
import re
45
45
import sqlalchemy
46
+ import sqlalchemy .orm
46
47
import sqlite3
47
48
48
49
# Get logger
@@ -59,6 +60,11 @@ def __init__(self, url, **kwargs):
59
60
# Create engine, disabling SQLAlchemy's own autocommit mode, raising exception if back end's module not installed
60
61
self ._engine = sqlalchemy .create_engine (url , ** kwargs ).execution_options (autocommit = False )
61
62
63
+ # Create a variable to hold the session. If None, autocommit is on.
64
+ self ._Session = sqlalchemy .orm .session .sessionmaker (bind = self ._engine )
65
+ self ._session = None
66
+ self ._in_transaction = False
67
+
62
68
# Listener for connections
63
69
def connect (dbapi_connection , connection_record ):
64
70
@@ -90,9 +96,8 @@ def connect(dbapi_connection, connection_record):
90
96
self ._logger .disabled = disabled
91
97
92
98
def __del__ (self ):
93
- """Close database connection."""
94
- if hasattr (self , "_connection" ):
95
- self ._connection .close ()
99
+ """Close database session and connection."""
100
+ self ._close_session ()
96
101
97
102
@_enable_logging
98
103
def execute (self , sql , * args , ** kwargs ):
@@ -125,6 +130,13 @@ def execute(self, sql, *args, **kwargs):
125
130
if token .ttype in [sqlparse .tokens .Keyword .DDL , sqlparse .tokens .Keyword .DML ]:
126
131
command = token .value .upper ()
127
132
break
133
+
134
+ # Begin a new session, if transaction started by caller (not using autocommit)
135
+ elif token .value .upper () in ["BEGIN" , "START" ]:
136
+ if self ._in_transaction :
137
+ raise RuntimeError ("transaction already open" )
138
+
139
+ self ._in_transaction = True
128
140
else :
129
141
command = None
130
142
@@ -272,6 +284,10 @@ def execute(self, sql, *args, **kwargs):
272
284
statement = "" .join ([str (token ) for token in tokens ])
273
285
274
286
# Connect to database (for transactions' sake)
287
+ if self ._session is None :
288
+ self ._session = self ._Session ()
289
+
290
+ # Set up a Flask app teardown function to close session at teardown
275
291
try :
276
292
277
293
# Infer whether Flask is installed
@@ -280,29 +296,17 @@ def execute(self, sql, *args, **kwargs):
280
296
# Infer whether app is defined
281
297
assert flask .current_app
282
298
283
- # If no connection for app's current request yet
284
- if not hasattr (flask .g , "_connection" ):
299
+ # Disconnect later - but only once
300
+ if not hasattr (self , "_teardown_appcontext_added" ):
301
+ self ._teardown_appcontext_added = True
285
302
286
- # Connect now
287
- flask .g ._connection = self ._engine .connect ()
288
-
289
- # Disconnect later
290
303
@flask .current_app .teardown_appcontext
291
304
def shutdown_session (exception = None ):
292
- if hasattr (flask .g , "_connection" ):
293
- flask .g ._connection .close ()
294
-
295
- # Use this connection
296
- connection = flask .g ._connection
305
+ """Close any existing session on app context teardown."""
306
+ self ._close_session ()
297
307
298
308
except (ModuleNotFoundError , AssertionError ):
299
-
300
- # If no connection yet
301
- if not hasattr (self , "_connection" ):
302
- self ._connection = self ._engine .connect ()
303
-
304
- # Use this connection
305
- connection = self ._connection
309
+ pass
306
310
307
311
# Catch SQLAlchemy warnings
308
312
with warnings .catch_warnings ():
@@ -316,8 +320,15 @@ def shutdown_session(exception=None):
316
320
# Join tokens into statement, abbreviating binary data as <class 'bytes'>
317
321
_statement = "" .join ([str (bytes ) if token .ttype == sqlparse .tokens .Other else str (token ) for token in tokens ])
318
322
323
+ # If COMMIT or ROLLBACK, turn on autocommit mode
324
+ if command in ["COMMIT" , "ROLLBACK" ] and "TO" not in (token .value for token in tokens ):
325
+ if not self ._in_transaction :
326
+ raise RuntimeError ("transactions must be initiated with BEGIN or START TRANSACTION" )
327
+
328
+ self ._in_transaction = False
329
+
319
330
# Execute statement
320
- result = connection .execute (sqlalchemy .text (statement ))
331
+ result = self . _session .execute (sqlalchemy .text (statement ))
321
332
322
333
# Return value
323
334
ret = True
@@ -346,7 +357,7 @@ def shutdown_session(exception=None):
346
357
elif command == "INSERT" :
347
358
if self ._engine .url .get_backend_name () in ["postgres" , "postgresql" ]:
348
359
try :
349
- result = connection .execute ("SELECT LASTVAL()" )
360
+ result = self . _session .execute ("SELECT LASTVAL()" )
350
361
ret = result .first ()[0 ]
351
362
except sqlalchemy .exc .OperationalError : # If lastval is not yet defined in this session
352
363
ret = None
@@ -357,6 +368,10 @@ def shutdown_session(exception=None):
357
368
elif command in ["DELETE" , "UPDATE" ]:
358
369
ret = result .rowcount
359
370
371
+ # If autocommit is on, commit
372
+ if not self ._in_transaction :
373
+ self ._session .commit ()
374
+
360
375
# If constraint violated, return None
361
376
except sqlalchemy .exc .IntegrityError as e :
362
377
self ._logger .debug (termcolor .colored (statement , "yellow" ))
@@ -376,6 +391,14 @@ def shutdown_session(exception=None):
376
391
self ._logger .debug (termcolor .colored (_statement , "green" ))
377
392
return ret
378
393
394
+ def _close_session (self ):
395
+ """Closes any existing session and resets instance variables."""
396
+ if self ._session is not None :
397
+ self ._session .close ()
398
+
399
+ self ._session = None
400
+ self ._in_transaction = False
401
+
379
402
def _escape (self , value ):
380
403
"""
381
404
Escapes value using engine's conversion function.
0 commit comments