Skip to content

Commit 1b671e2

Browse files
author
Kareem Zidane
committed
refactor, fix scoped session
1 parent 599d968 commit 1b671e2

File tree

8 files changed

+684
-638
lines changed

8 files changed

+684
-638
lines changed

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,5 @@
1616
package_dir={"": "src"},
1717
packages=["cs50"],
1818
url="https://github.com/cs50/python-cs50",
19-
version="6.0.4"
19+
version="7.0.0"
2020
)

src/cs50/__init__.py

+3-17
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,6 @@
1-
import logging
2-
import os
3-
import sys
1+
from ._logger import _setup_logger
2+
_setup_logger()
43

5-
6-
# Disable cs50 logger by default
7-
logging.getLogger("cs50").disabled = True
8-
9-
# Import cs50_*
10-
from .cs50 import get_char, get_float, get_int, get_string
11-
try:
12-
from .cs50 import get_long
13-
except ImportError:
14-
pass
15-
16-
# Hook into flask importing
4+
from .cs50 import get_float, get_int, get_string
175
from . import flask
18-
19-
# Wrap SQLAlchemy
206
from .sql import SQL

src/cs50/_logger.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import logging
2+
import os.path
3+
import re
4+
import sys
5+
import traceback
6+
7+
import termcolor
8+
9+
10+
def _setup_logger():
11+
_logger = logging.getLogger("cs50")
12+
_logger.disabled = True
13+
_logger.setLevel(logging.DEBUG)
14+
15+
# Log messages once
16+
_logger.propagate = False
17+
18+
handler = logging.StreamHandler()
19+
handler.setLevel(logging.DEBUG)
20+
21+
formatter = logging.Formatter("%(levelname)s: %(message)s")
22+
formatter.formatException = lambda exc_info: _formatException(*exc_info)
23+
handler.setFormatter(formatter)
24+
_logger.addHandler(handler)
25+
26+
27+
def _formatException(type, value, tb):
28+
"""
29+
Format traceback, darkening entries from global site-packages directories
30+
and user-specific site-packages directory.
31+
https://stackoverflow.com/a/46071447/5156190
32+
"""
33+
34+
# Absolute paths to site-packages
35+
packages = tuple(os.path.join(os.path.abspath(p), "") for p in sys.path[1:])
36+
37+
# Highlight lines not referring to files in site-packages
38+
lines = []
39+
for line in traceback.format_exception(type, value, tb):
40+
matches = re.search(r"^ File \"([^\"]+)\", line \d+, in .+", line)
41+
if matches and matches.group(1).startswith(packages):
42+
lines += line
43+
else:
44+
matches = re.search(r"^(\s*)(.*?)(\s*)$", line, re.DOTALL)
45+
lines.append(matches.group(1) + termcolor.colored(matches.group(2), "yellow") + matches.group(3))
46+
return "".join(lines).rstrip()
47+
48+

src/cs50/_session.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import os
2+
3+
import sqlalchemy
4+
import sqlalchemy.orm
5+
import sqlite3
6+
7+
class Session:
8+
def __init__(self, url, **engine_kwargs):
9+
self._url = url
10+
if _is_sqlite_url(self._url):
11+
_assert_sqlite_file_exists(self._url)
12+
13+
self._engine = _create_engine(self._url, **engine_kwargs)
14+
self._is_postgres = self._engine.url.get_backend_name() in {"postgres", "postgresql"}
15+
_setup_on_connect(self._engine)
16+
self._session = _create_scoped_session(self._engine)
17+
18+
19+
def is_postgres(self):
20+
return self._is_postgres
21+
22+
23+
def execute(self, statement):
24+
return self._session.execute(sqlalchemy.text(str(statement)))
25+
26+
27+
def __getattr__(self, attr):
28+
return getattr(self._session, attr)
29+
30+
31+
def _is_sqlite_url(url):
32+
return url.startswith("sqlite:///")
33+
34+
35+
def _assert_sqlite_file_exists(url):
36+
path = url[len("sqlite:///"):]
37+
if not os.path.exists(path):
38+
raise RuntimeError(f"does not exist: {path}")
39+
if not os.path.isfile(path):
40+
raise RuntimeError(f"not a file: {path}")
41+
42+
43+
def _create_engine(url, **kwargs):
44+
try:
45+
engine = sqlalchemy.create_engine(url, **kwargs)
46+
except sqlalchemy.exc.ArgumentError:
47+
raise RuntimeError(f"invalid URL: {url}") from None
48+
49+
engine.execution_options(autocommit=False)
50+
return engine
51+
52+
53+
def _setup_on_connect(engine):
54+
def connect(dbapi_connection, _):
55+
_disable_auto_begin_commit(dbapi_connection)
56+
if _is_sqlite_connection(dbapi_connection):
57+
_enable_sqlite_foreign_key_constraints(dbapi_connection)
58+
59+
sqlalchemy.event.listen(engine, "connect", connect)
60+
61+
62+
def _create_scoped_session(engine):
63+
session_factory = sqlalchemy.orm.sessionmaker(bind=engine)
64+
return sqlalchemy.orm.scoping.scoped_session(session_factory)
65+
66+
67+
def _disable_auto_begin_commit(dbapi_connection):
68+
# Disable underlying API's own emitting of BEGIN and COMMIT so we can ourselves
69+
# https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
70+
dbapi_connection.isolation_level = None
71+
72+
73+
def _is_sqlite_connection(dbapi_connection):
74+
return isinstance(dbapi_connection, sqlite3.Connection)
75+
76+
77+
def _enable_sqlite_foreign_key_constraints(dbapi_connection):
78+
cursor = dbapi_connection.cursor()
79+
cursor.execute("PRAGMA foreign_keys=ON")
80+
cursor.close()

0 commit comments

Comments
 (0)