Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: upgrade FAB, sqlalchemy and flask-sqlalchemy #29094

Draft
wants to merge 26 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dependencies = [
"cryptography>=42.0.4, <43.0.0",
"deprecation>=2.1.0, <2.2.0",
"flask>=2.2.5, <3.0.0",
"flask-appbuilder>=4.5.0, <5.0.0",
"flask-appbuilder==5.0.0a2",
"flask-caching>=2.1.0, <3",
"flask-compress>=1.13, <2.0",
"flask-talisman>=1.0.0, <2.0",
Expand All @@ -67,7 +67,7 @@ dependencies = [
"nh3>=0.2.11, <0.3",
"numpy==1.23.5",
"packaging",
"pandas[performance]>=2.0.3, <2.1",
"pandas[performance]>=2.0.3, <2.3",
"parsedatetime",
"paramiko>=3.4.0",
"pgsanity",
Expand All @@ -86,7 +86,7 @@ dependencies = [
"sshtunnel>=0.4.0, <0.5",
"simplejson>=3.15.0",
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy>=2",
"sqlalchemy-utils>=0.38.3, <0.39",
"sqlglot>=23.0.2,<24",
"sqlparse>=0.5.0",
Expand Down
15 changes: 7 additions & 8 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ flask==2.3.3
# flask-session
# flask-sqlalchemy
# flask-wtf
flask-appbuilder==4.5.0
flask-appbuilder==5.0.0a2
# via apache-superset
flask-babel==2.0.0
# via flask-appbuilder
Expand All @@ -125,7 +125,7 @@ flask-migrate==3.1.0
# via apache-superset
flask-session==0.8.0
# via apache-superset
flask-sqlalchemy==2.5.1
flask-sqlalchemy==3.1.1
# via
# flask-appbuilder
# flask-migrate
Expand All @@ -144,9 +144,7 @@ geopy==2.4.1
google-auth==2.29.0
# via shillelagh
greenlet==3.0.3
# via
# shillelagh
# sqlalchemy
# via shillelagh
gunicorn==22.0.0
# via apache-superset
hashids==1.3.1
Expand Down Expand Up @@ -201,7 +199,7 @@ marshmallow==3.21.2
# via
# flask-appbuilder
# marshmallow-sqlalchemy
marshmallow-sqlalchemy==0.28.2
marshmallow-sqlalchemy==0.30.0
# via flask-appbuilder
mdurl==0.1.2
# via markdown-it-py
Expand Down Expand Up @@ -237,7 +235,7 @@ packaging==23.2
# marshmallow
# marshmallow-sqlalchemy
# shillelagh
pandas[performance]==2.0.3
pandas[performance]==2.2.2
# via apache-superset
paramiko==3.4.0
# via
Expand Down Expand Up @@ -331,7 +329,7 @@ six==1.16.0
# wtforms-json
slack-sdk==3.27.2
# via apache-superset
sqlalchemy==1.4.52
sqlalchemy==2.0.30
# via
# alembic
# apache-superset
Expand Down Expand Up @@ -359,6 +357,7 @@ typing-extensions==4.12.0
# flask-limiter
# limits
# shillelagh
# sqlalchemy
tzdata==2024.1
# via
# celery
Expand Down
15 changes: 2 additions & 13 deletions requirements/development.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
# via
# -r requirements/base.in
# -r requirements/development.in
appnope==0.1.4
# via ipython
astroid==3.1.0
# via pylint
boto3==1.34.112
Expand Down Expand Up @@ -236,14 +234,8 @@ thrift==0.16.0
# thrift-sasl
thrift-sasl==0.4.3
# via
# build
# coverage
# pip-tools
# pylint
# pyproject-api
# pyproject-hooks
# pytest
# tox
# apache-superset
# pyhive
tomlkit==0.12.5
# via pylint
toposort==1.10
Expand All @@ -254,9 +246,6 @@ tqdm==4.66.4
# via
# cmdstanpy
# prophet
traitlets==5.14.3
# via
# matplotlib-inline
trino==0.328.0
# via apache-superset
tzlocal==5.2
Expand Down
2 changes: 1 addition & 1 deletion scripts/tests/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -138,5 +138,5 @@ fi

if [ $RUN_TESTS -eq 1 ]
then
pytest -vv --durations=0 "${TEST_MODULE}"
pytest -vv --durations=0 --maxfail=1 "${TEST_MODULE}"
fi
2 changes: 1 addition & 1 deletion superset/cli/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,4 @@ def load_test_users_run() -> None:
sm.find_role(role),
password="general",
)
sm.get_session.commit()
sm.session.commit()
2 changes: 1 addition & 1 deletion superset/commands/dashboard/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _import(configs: dict[str, Any], overwrite: bool = False) -> None:

# store the existing relationship between dashboards and charts
existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
select(dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id)
).fetchall()

# import dashboards
Expand Down
10 changes: 5 additions & 5 deletions superset/commands/dataset/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from flask import current_app
from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.sql.visitors import VisitableType
from sqlalchemy.sql.visitors import Visitable

from superset import db, security_manager
from superset.commands.dataset.exceptions import DatasetForbiddenDataURI
Expand Down Expand Up @@ -59,7 +59,7 @@
}


def get_sqla_type(native_type: str) -> VisitableType:
def get_sqla_type(native_type: str) -> Visitable:
if native_type.upper() in type_map:
return type_map[native_type.upper()]

Expand All @@ -72,7 +72,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
)


def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]:
def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, Visitable]:
return {
column.column_name: get_sqla_type(column.type)
for column in dataset.columns
Expand Down Expand Up @@ -179,8 +179,8 @@ def import_dataset(
load_data(data_uri, dataset, dataset.database)

if user := get_user():
dataset.owners.append(user)

if user not in dataset.owners:
dataset.owners.append(user)
return dataset


Expand Down
2 changes: 1 addition & 1 deletion superset/commands/importers/v1/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _import( # pylint: disable=too-many-locals, too-many-branches

# store the existing relationship between dashboards and charts
existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
select(dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id)
).fetchall()

# import dashboards
Expand Down
8 changes: 4 additions & 4 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Callable, cast
from typing import Any, Callable, cast, List

import dateutil.parser
import numpy as np
Expand Down Expand Up @@ -239,7 +239,7 @@ def is_virtual(self) -> bool:
return self.kind == DatasourceKind.VIRTUAL

@declared_attr
def slices(self) -> RelationshipProperty:
def slices(self) -> Mapped[List[Slice]]:
return relationship(
"Slice",
overlaps="table",
Expand Down Expand Up @@ -1148,7 +1148,7 @@ class SqlaTable(
database_id = Column(Integer, ForeignKey("dbs.id"), nullable=False)
fetch_values_predicate = Column(Text)
owners = relationship(owner_class, secondary=sqlatable_user, backref="tables")
database: Database = relationship(
database: Mapped[Database] = relationship(
"Database",
backref=backref("tables", cascade="all, delete-orphan"),
foreign_keys=[database_id],
Expand Down Expand Up @@ -1551,7 +1551,7 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
try:
# probe adhoc column type
tbl, _ = self.get_from_clause(template_processor)
qry = sa.select([sqla_column]).limit(1).select_from(tbl)
qry = sa.select(sqla_column).limit(1).select_from(tbl)
sql = self.database.compile_sqla_query(qry)
col_desc = get_columns_description(
self.database,
Expand Down
8 changes: 4 additions & 4 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def find_by_id(
"""
query = db.session.query(cls.model_cls)
if cls.base_filter and not skip_base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
data_model = SQLAInterface(cls.model_cls)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
Expand All @@ -90,7 +90,7 @@ def find_by_ids(
return []
query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
if cls.base_filter and not skip_base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
data_model = SQLAInterface(cls.model_cls)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
Expand All @@ -103,7 +103,7 @@ def find_all(cls) -> list[T]:
"""
query = db.session.query(cls.model_cls)
if cls.base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
data_model = SQLAInterface(cls.model_cls)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
Expand All @@ -116,7 +116,7 @@ def find_one_or_none(cls, **filter_by: Any) -> T | None:
"""
query = db.session.query(cls.model_cls)
if cls.base_filter:
data_model = SQLAInterface(cls.model_cls, db.session)
data_model = SQLAInterface(cls.model_cls)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
Expand Down
2 changes: 1 addition & 1 deletion superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_by_id_or_slug(cls, id_or_slug: int | str) -> Dashboard:
.outerjoin(Dashboard.roles)
)
# Apply dashboard base filters
query = cls.base_filter("id", SQLAInterface(Dashboard, db.session)).apply(
query = cls.base_filter("id", SQLAInterface(Dashboard)).apply(
query, None
)
dashboard = query.one_or_none()
Expand Down
2 changes: 1 addition & 1 deletion superset/daos/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def update_saved_query_exec_info(query_id: int) -> None:
:param query_id: The query id
:return:
"""
query = db.session.query(Query).get(query_id)
query = db.session.get(Query, query_id)
related_saved_queries = (
db.session.query(SavedQuery)
.filter(SavedQuery.database == query.database)
Expand Down
4 changes: 0 additions & 4 deletions superset/daos/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ def create_custom_tagged_objects(
for name in clean_tag_names:
type_ = TagType.custom
tag = TagDAO.get_by_name(name, type_)
tagged_objects.append(
TaggedObject(object_id=object_id, object_type=object_type, tag=tag)
)

# Check if the association already exists
existing_tagged_object = (
db.session.query(TaggedObject)
Expand Down
4 changes: 2 additions & 2 deletions superset/databases/ssh_tunnel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sqlalchemy as sa
from flask import current_app
from flask_appbuilder import Model
from sqlalchemy.orm import backref, relationship
from sqlalchemy.orm import backref, relationship, Mapped
from sqlalchemy.types import Text

from superset.constants import PASSWORD_MASK
Expand All @@ -46,7 +46,7 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
database_id = sa.Column(
sa.Integer, sa.ForeignKey("dbs.id"), nullable=False, unique=True
)
database: Database = relationship(
database: Mapped[Database] = relationship(
"Database",
backref=backref("ssh_tunnels", uselist=False, cascade="all, delete-orphan"),
foreign_keys=[database_id],
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,7 +1660,7 @@ def select_star( # pylint: disable=too-many-arguments,too-many-locals
else quote(table.table)
)

qry = select(fields).select_from(text(full_table_name))
qry = select(*fields).select_from(text(full_table_name))

if limit and cls.allow_limit_clause:
qry = qry.limit(limit)
Expand Down
17 changes: 8 additions & 9 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Any, TYPE_CHECKING

from flask_babel import gettext as __
from sqlalchemy import text
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
from sqlalchemy.engine.reflection import Inspector
Expand Down Expand Up @@ -375,15 +376,13 @@ def get_catalog_names(

In Postgres, a catalog is called a "database".
"""
return {
catalog
for (catalog,) in inspector.bind.execute(
"""
SELECT datname FROM pg_database
WHERE datistemplate = false;
"""
)
}
catalogs = set()
with inspector.bind.connect() as connection:
for (catalog,) in connection.execute(
text("SELECT datname FROM pg_database WHERE datistemplate = false;")
):
catalogs.add(catalog)
return catalogs

@classmethod
def get_table_names(
Expand Down
4 changes: 2 additions & 2 deletions superset/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@

import celery
from flask import Flask
from flask_appbuilder import AppBuilder, SQLA
from flask_appbuilder import AppBuilder
from flask_appbuilder.extensions import db
from flask_caching.backends.base import BaseCache
from flask_migrate import Migrate
from flask_talisman import Talisman
Expand Down Expand Up @@ -122,7 +123,6 @@ def init_app(self, app: Flask) -> None:
cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
db = SQLA() # pylint: disable=disallowed-name
_event_logger: dict[str, Any] = {}
encrypted_field_factory = EncryptedFieldFactory()
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
Expand Down
6 changes: 3 additions & 3 deletions superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def configure_fab(self) -> None:
appbuilder.indexview = SupersetIndexView
appbuilder.base_template = "superset/base.html"
appbuilder.security_manager_class = custom_sm
appbuilder.init_app(self.superset_app, db.session)
appbuilder.init_app(self.superset_app)

def configure_url_map_converters(self) -> None:
#
Expand Down Expand Up @@ -628,8 +628,8 @@ def configure_db_encrypt(self) -> None:
def setup_db(self) -> None:
db.init_app(self.superset_app)

with self.superset_app.app_context():
pessimistic_connection_handling(db.engine)
# with self.superset_app.app_context():
# pessimistic_connection_handling(db.engine)

migrate.init_app(self.superset_app, db=db, directory=APP_DIR + "/migrations")

Expand Down
Loading
Loading