Skip to content

Commit bdb8128

Browse files
author
Kareem Zidane
authored
Merge pull request #50 from cs50/foreign-key-constraint
added foreign key constraint support to SQLite
2 parents f19934d + afb6cba commit bdb8128

File tree

4 files changed

+45
-8
lines changed

4 files changed

+45
-8
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ install:
1414
before_script:
1515
- mysql -e 'CREATE DATABASE IF NOT EXISTS test;'
1616
- psql -c 'create database test;' -U postgres
17-
- touch test.db
17+
- touch test.db test1.db
1818
script: python tests/sql.py
1919
after_script: rm -f test.db
2020
jobs:

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="2.4.0"
19+
version="2.4.1"
2020
)

src/cs50/sql.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import re
77
import sqlalchemy
8+
import sqlite3
89
import sqlparse
910
import sys
1011
import termcolor
@@ -32,12 +33,25 @@ def __init__(self, url, **kwargs):
3233
if not os.path.isfile(matches.group(1)):
3334
raise RuntimeError("not a file: {}".format(matches.group(1)))
3435

35-
# Create engine, raising exception if back end's module not installed
36-
self.engine = sqlalchemy.create_engine(url, **kwargs)
36+
# Remember foreign_keys and remove it from kwargs
37+
foreign_keys = kwargs.pop("foreign_keys", False)
38+
39+
# Create engine, raising exception if back end's module not installed
40+
self.engine = sqlalchemy.create_engine(url, **kwargs)
41+
42+
# Enable foreign key constraints
43+
if foreign_keys:
44+
sqlalchemy.event.listen(self.engine, "connect", _connect)
45+
else:
46+
47+
# Create engine, raising exception if back end's module not installed
48+
self.engine = sqlalchemy.create_engine(url, **kwargs)
49+
3750

3851
# Log statements to standard error
3952
logging.basicConfig(level=logging.DEBUG)
4053
self.logger = logging.getLogger("cs50")
54+
disabled = self.logger.disabled
4155

4256
# Test database
4357
try:
@@ -48,7 +62,7 @@ def __init__(self, url, **kwargs):
4862
e.__cause__ = None
4963
raise e
5064
else:
51-
self.logger.disabled = False
65+
self.logger.disabled = disabled
5266

5367
def _parse(self, e):
5468
"""Parses an exception, returns its message."""
@@ -133,6 +147,8 @@ def process(value):
133147
return process(value)
134148

135149
# Allow only one statement at a time
150+
# SQLite does not support executing many statements
151+
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.execute
136152
if len(sqlparse.split(text)) > 1:
137153
raise RuntimeError("too many statements at once")
138154

@@ -211,3 +227,16 @@ def process(value):
211227
else:
212228
self.logger.debug(termcolor.colored(log, "green"))
213229
return ret
230+
231+
232+
# http://docs.sqlalchemy.org/en/latest/dialects/sqlite.html#foreign-key-support
233+
def _connect(dbapi_connection, connection_record):
234+
"""Enables foreign key support."""
235+
236+
# Ensure backend is sqlite
237+
if type(dbapi_connection) is sqlite3.Connection:
238+
cursor = dbapi_connection.cursor()
239+
240+
# Respect foreign key constraints by default
241+
cursor.execute("PRAGMA foreign_keys=ON")
242+
cursor.close()

tests/sql.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -107,18 +107,26 @@ class SQLiteTests(SQLTests):
107107
@classmethod
108108
def setUpClass(self):
109109
self.db = SQL("sqlite:///test.db")
110+
self.db1 = SQL("sqlite:///test1.db", foreign_keys=True)
110111

111112
def setUp(self):
112113
self.db.execute("CREATE TABLE cs50(id INTEGER PRIMARY KEY, val TEXT)")
113114

114-
def multi_inserts_enabled(self):
115-
return False
115+
def test_foreign_key_support(self):
116+
self.db.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
117+
self.db.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
118+
self.assertEqual(self.db.execute("INSERT INTO bar VALUES(50)"), 1)
119+
120+
self.db1.execute("CREATE TABLE foo(id INTEGER PRIMARY KEY)")
121+
self.db1.execute("CREATE TABLE bar(foo_id INTEGER, FOREIGN KEY (foo_id) REFERENCES foo(id))")
122+
self.assertEqual(self.db1.execute("INSERT INTO bar VALUES(50)"), None)
116123

117124
if __name__ == "__main__":
118125
suite = unittest.TestSuite([
119126
unittest.TestLoader().loadTestsFromTestCase(SQLiteTests),
120127
unittest.TestLoader().loadTestsFromTestCase(MySQLTests),
121128
unittest.TestLoader().loadTestsFromTestCase(PostgresTests)
122129
])
123-
logging.getLogger("cs50.sql").disabled = True
130+
131+
logging.getLogger("cs50").disabled = True
124132
sys.exit(not unittest.TextTestRunner(verbosity=2).run(suite).wasSuccessful())

0 commit comments

Comments
 (0)