5
5
import os
6
6
import re
7
7
import sqlalchemy
8
+ import sqlite3
8
9
import sqlparse
9
10
import sys
10
11
import termcolor
@@ -32,12 +33,25 @@ def __init__(self, url, **kwargs):
32
33
if not os .path .isfile (matches .group (1 )):
33
34
raise RuntimeError ("not a file: {}" .format (matches .group (1 )))
34
35
35
- # Create engine, raising exception if back end's module not installed
36
- self .engine = sqlalchemy .create_engine (url , ** kwargs )
36
+ # Remember foreign_keys and remove it from kwargs
37
+ foreign_keys = kwargs .pop ("foreign_keys" , False )
38
+
39
+ # Create engine, raising exception if back end's module not installed
40
+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
41
+
42
+ # Enable foreign key constraints
43
+ if foreign_keys :
44
+ sqlalchemy .event .listen (self .engine , "connect" , _connect )
45
+ else :
46
+
47
+ # Create engine, raising exception if back end's module not installed
48
+ self .engine = sqlalchemy .create_engine (url , ** kwargs )
49
+
37
50
38
51
# Log statements to standard error
39
52
logging .basicConfig (level = logging .DEBUG )
40
53
self .logger = logging .getLogger ("cs50" )
54
+ disabled = self .logger .disabled
41
55
42
56
# Test database
43
57
try :
@@ -48,7 +62,7 @@ def __init__(self, url, **kwargs):
48
62
e .__cause__ = None
49
63
raise e
50
64
else :
51
- self .logger .disabled = False
65
+ self .logger .disabled = disabled
52
66
53
67
def _parse (self , e ):
54
68
"""Parses an exception, returns its message."""
@@ -133,6 +147,8 @@ def process(value):
133
147
return process (value )
134
148
135
149
# Allow only one statement at a time
150
+ # SQLite does not support executing many statements
151
+ # https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
136
152
if len (sqlparse .split (text )) > 1 :
137
153
raise RuntimeError ("too many statements at once" )
138
154
@@ -211,3 +227,16 @@ def process(value):
211
227
else :
212
228
self .logger .debug (termcolor .colored (log , "green" ))
213
229
return ret
230
+
231
+
232
+ # http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support
233
+ def _connect (dbapi_connection , connection_record ):
234
+ """Enables foreign key support."""
235
+
236
+ # Ensure backend is sqlite
237
+ if type (dbapi_connection ) is sqlite3 .Connection :
238
+ cursor = dbapi_connection .cursor ()
239
+
240
+ # Respect foreign key constraints by default
241
+ cursor .execute ("PRAGMA foreign_keys=ON" )
242
+ cursor .close ()
0 commit comments