Skip to content

Commit d7d073c

Browse files
committed
Introduce private sqlalchemy 2.0 compatibility helpers
1 parent be4b7b9 commit d7d073c

File tree

3 files changed

+94
-10
lines changed

3 files changed

+94
-10
lines changed

conftest.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33

44
import pytest
55
import sqlalchemy as sa
6+
import sqlalchemy.event
7+
import sqlalchemy.exc
68
from sqlalchemy import create_engine
7-
from sqlalchemy.ext.declarative import declarative_base, synonym_for
89
from sqlalchemy.ext.hybrid import hybrid_property
910
from sqlalchemy.orm import sessionmaker
1011
from sqlalchemy.orm.session import close_all_sessions
@@ -15,6 +16,11 @@
1516
i18n,
1617
InstrumentedList
1718
)
19+
from sqlalchemy_utils.compat import (
20+
_declarative_base,
21+
_select_args,
22+
_synonym_for
23+
)
1824
from sqlalchemy_utils.functions.orm import _get_class_registry
1925
from sqlalchemy_utils.types.pg_composite import remove_composite_listeners
2026

@@ -148,7 +154,7 @@ def connection(engine):
148154

149155
@pytest.fixture
150156
def Base():
151-
return declarative_base()
157+
return _declarative_base()
152158

153159

154160
@pytest.fixture
@@ -185,7 +191,7 @@ def articles_count(self):
185191
def articles_count(cls):
186192
Article = _get_class_registry(Base)['Article']
187193
return (
188-
sa.select([sa.func.count(Article.id)])
194+
sa.select(*_select_args(sa.func.count(Article.id)))
189195
.where(Article.category_id == cls.id)
190196
.correlate(Article.__table__)
191197
.label('article_count')
@@ -195,7 +201,7 @@ def articles_count(cls):
195201
def name_alias(self):
196202
return self.name
197203

198-
@synonym_for('name')
204+
@_synonym_for('name')
199205
@property
200206
def name_synonym(self):
201207
return self.name
@@ -229,15 +235,22 @@ def init_models(User, Category, Article):
229235
@pytest.fixture
230236
def session(request, engine, connection, Base, init_models):
231237
sa.orm.configure_mappers()
232-
Base.metadata.create_all(connection)
238+
with connection.begin():
239+
Base.metadata.create_all(connection)
233240
Session = sessionmaker(bind=connection)
234-
session = Session()
241+
try:
242+
# Enable sqlalchemy 2.0 behavior.
243+
session = Session(future=True)
244+
except TypeError:
245+
# sqlalchemy 1.3
246+
session = Session()
235247
i18n.get_locale = get_locale
236248

237249
def teardown():
238250
aggregates.manager.reset()
239251
close_all_sessions()
240-
Base.metadata.drop_all(connection)
252+
with connection.begin():
253+
Base.metadata.drop_all(connection)
241254
remove_composite_listeners()
242255
connection.close()
243256
engine.dispose()

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def get_version():
7575
platforms='any',
7676
install_requires=[
7777
'SQLAlchemy>=1.0',
78+
"importlib_metadata ; python_version<'3.8'",
7879
],
7980
extras_require=extras_require,
8081
python_requires='~=3.6',

sqlalchemy_utils/compat.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,75 @@
1-
def get_scalar_subquery(query):
2-
try:
1+
import sys
2+
3+
if sys.version_info >= (3, 8):
4+
from importlib.metadata import metadata
5+
else:
6+
from importlib_metadata import metadata
7+
8+
9+
_sqlalchemy_version = tuple(
10+
[int(i) for i in metadata("sqlalchemy")["Version"].split(".")[:2]]
11+
)
12+
13+
14+
# In sqlalchemy 2.0, some functions moved to sqlalchemy.orm.
15+
# In sqlalchemy 1.3, they are only available in .ext.declarative.
16+
# In sqlalchemy 1.4, they are available in both places.
17+
#
18+
# WARNING
19+
# -------
20+
#
21+
# These imports are for internal, private compatibility.
22+
# They are not supported and may change or move at any time.
23+
# Do not import these in your own code.
24+
#
25+
26+
if _sqlalchemy_version >= (1, 4):
27+
from sqlalchemy.orm import declarative_base as _declarative_base
28+
from sqlalchemy.orm import synonym_for as _synonym_for
29+
else:
30+
from sqlalchemy.ext.declarative import declarative_base as _declarative_base
31+
from sqlalchemy.ext.declarative import synonym_for as _synonym_for
32+
33+
34+
# scalar subqueries
35+
if _sqlalchemy_version >= (1, 4):
36+
def get_scalar_subquery(query):
337
return query.scalar_subquery()
4-
except AttributeError: # SQLAlchemy <1.4
38+
else:
39+
def get_scalar_subquery(query):
540
return query.as_scalar()
41+
42+
43+
# In sqlalchemy 2.0, select() columns are positional.
44+
# In sqlalchemy 1.3, select() columns must be wrapped in a list.
45+
#
46+
# _select_args() is designed so its return value can be unpacked:
47+
#
48+
# select(*_select_args(1, 2))
49+
#
50+
# When sqlalchemy 1.3 support is dropped, remove the call to _select_args()
51+
# and keep the arguments the same:
52+
#
53+
# select(1, 2)
54+
#
55+
# WARNING
56+
# -------
57+
#
58+
# _select_args() is a private, internal function.
59+
# It is not supported and may change or move at any time.
60+
# Do not import this in your own code.
61+
#
62+
if _sqlalchemy_version >= (1, 4):
63+
def _select_args(*args):
64+
return args
65+
else:
66+
def _select_args(*args):
67+
return [args]
68+
69+
70+
__all__ = (
71+
"_declarative_base",
72+
"get_scalar_subquery",
73+
"_select_args",
74+
"_synonym_for",
75+
)

0 commit comments

Comments
 (0)