Skip to content

Commit

Permalink
chore(dao): Add explicit ON DELETE CASCADE when deleting datasets (#2…
Browse files Browse the repository at this point in the history
  • Loading branch information
john-bodley committed Jun 28, 2023
1 parent f1b003f commit 75543af
Show file tree
Hide file tree
Showing 15 changed files with 144 additions and 78 deletions.
1 change: 1 addition & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ assists people when migrating to a new version.

## Next

- [24488](https://github.com/apache/superset/pull/24488): Augments the foreign key constraints for the `sql_metrics`, `sqlatable_user`, and `table_columns` tables which reference the `tables` table to include an explicit CASCADE ON DELETE to ensure the relevant records are deleted when a dataset is deleted. Scheduled downtime may be advised.
- [24335](https://github.com/apache/superset/pull/24335): Removed deprecated API `/superset/filter/<datasource_type>/<int:datasource_id>/<column>/`
- [24185](https://github.com/apache/superset/pull/24185): `/api/v1/database/test_connection` and `api/v1/database/validate_parameters` permissions changed from `can_read` to `can_write`. Only Admin user's have access.
- [24232](https://github.com/apache/superset/pull/24232): Enables ENABLE_TEMPLATE_REMOVE_FILTERS, DRILL_TO_DETAIL, DASHBOARD_CROSS_FILTERS by default, marks VERSIONED_EXPORT and ENABLE_TEMPLATE_REMOVE_FILTERS as deprecated.
Expand Down
10 changes: 6 additions & 4 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):

__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="columns",
Expand Down Expand Up @@ -400,7 +400,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):

__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="metrics",
Expand Down Expand Up @@ -470,8 +470,8 @@ def data(self) -> dict[str, Any]:
"sqlatable_user",
metadata,
Column("id", Integer, primary_key=True),
Column("user_id", Integer, ForeignKey("ab_user.id")),
Column("table_id", Integer, ForeignKey("tables.id")),
Column("user_id", Integer, ForeignKey("ab_user.id", ondelete="CASCADE")),
Column("table_id", Integer, ForeignKey("tables.id", ondelete="CASCADE")),
)


Expand Down Expand Up @@ -508,11 +508,13 @@ class SqlaTable(
TableColumn,
back_populates="table",
cascade="all, delete-orphan",
passive_deletes=True,
)
metrics: Mapped[list[SqlMetric]] = relationship(
SqlMetric,
back_populates="table",
cascade="all, delete-orphan",
passive_deletes=True,
)
metric_class = SqlMetric
column_class = TableColumn
Expand Down
26 changes: 13 additions & 13 deletions superset/daos/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from sqlalchemy.exc import SQLAlchemyError

from superset import security_manager
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.daos.base import BaseDAO
from superset.extensions import db
Expand Down Expand Up @@ -361,25 +362,24 @@ def create_metric(
"""
return DatasetMetricDAO.create(properties, commit=commit)

@staticmethod
def bulk_delete(models: Optional[list[SqlaTable]], commit: bool = True) -> None:
@classmethod
def bulk_delete(
cls, models: Optional[list[SqlaTable]], commit: bool = True
) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
for model in models:
model.owners = []
db.session.merge(model)
db.session.query(SqlMetric).filter(SqlMetric.table_id.in_(item_ids)).delete(
synchronize_session="fetch"
)
db.session.query(TableColumn).filter(
TableColumn.table_id.in_(item_ids)
).delete(synchronize_session="fetch")
# bulk delete itself
try:
db.session.query(SqlaTable).filter(SqlaTable.id.in_(item_ids)).delete(
synchronize_session="fetch"
)

if models:
connection = db.session.connection()
mapper = next(iter(cls.model_cls.registry.mappers)) # type: ignore

for model in models:
security_manager.dataset_after_delete(mapper, connection, model)

if commit:
db.session.commit()
except SQLAlchemyError as ex:
Expand Down
30 changes: 2 additions & 28 deletions superset/datasets/commands/bulk_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
DatasetNotFoundError,
)
from superset.exceptions import SupersetSecurityException
from superset.extensions import db

logger = logging.getLogger(__name__)

Expand All @@ -40,35 +39,10 @@ def __init__(self, model_ids: list[int]):

def run(self) -> None:
self.validate()
if not self._models:
return None
assert self._models

try:
DatasetDAO.bulk_delete(self._models)
for model in self._models:
view_menu = (
security_manager.find_view_menu(model.get_perm()) if model else None
)

if view_menu:
permission_views = (
db.session.query(security_manager.permissionview_model)
.filter_by(view_menu=view_menu)
.all()
)

for permission_view in permission_views:
db.session.delete(permission_view)
if view_menu:
db.session.delete(view_menu)
else:
if not view_menu:
logger.error(
"Could not find the data access permission for the dataset",
exc_info=True,
)
db.session.commit()

return None
except DeleteFailedError as ex:
logger.exception(ex.exception)
raise DatasetBulkDeleteFailedError() from ex
Expand Down
18 changes: 5 additions & 13 deletions superset/datasets/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import cast, Optional
from typing import Optional

from flask_appbuilder.models.sqla import Model
from sqlalchemy.exc import SQLAlchemyError

from superset import security_manager
from superset.commands.base import BaseCommand
Expand All @@ -31,7 +30,6 @@
DatasetNotFoundError,
)
from superset.exceptions import SupersetSecurityException
from superset.extensions import db

logger = logging.getLogger(__name__)

Expand All @@ -43,19 +41,13 @@ def __init__(self, model_id: int):

def run(self) -> Model:
self.validate()
self._model = cast(SqlaTable, self._model)
assert self._model

try:
# Even though SQLAlchemy should in theory delete rows from the association
# table, sporadically Superset will error because the rows are not deleted.
# Let's do it manually here to prevent the error.
self._model.owners = []
dataset = DatasetDAO.delete(self._model, commit=False)
db.session.commit()
except (SQLAlchemyError, DAODeleteFailedError) as ex:
return DatasetDAO.delete(self._model)
except DAODeleteFailedError as ex:
logger.exception(ex)
db.session.rollback()
raise DatasetDeleteFailedError() from ex
return dataset

def validate(self) -> None:
# Validate/populate model exists
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add on delete cascade for tables references
Revision ID: 6fbe660cac39
Revises: 83e1abbe777f
Create Date: 2023-06-22 13:39:47.989373
"""
from __future__ import annotations

# revision identifiers, used by Alembic.
revision = "6fbe660cac39"
down_revision = "83e1abbe777f"

import sqlalchemy as sa
from alembic import op

from superset.utils.core import generic_find_fk_constraint_name


def migrate(ondelete: str | None) -> None:
"""
Redefine the foreign key constraints, via a successive DROP and ADD, for all tables
related to the `tables` table to include the ON DELETE construct for cascading
purposes.
:param ondelete: If set, emit ON DELETE <value> when issuing DDL for this constraint
"""

bind = op.get_bind()
insp = sa.engine.reflection.Inspector.from_engine(bind)

conv = {
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
}

for table in ("sql_metrics", "table_columns"):
with op.batch_alter_table(table, naming_convention=conv) as batch_op:
if constraint := generic_find_fk_constraint_name(
table=table,
columns={"id"},
referenced="tables",
insp=insp,
):
batch_op.drop_constraint(constraint, type_="foreignkey")

batch_op.create_foreign_key(
constraint_name=f"fk_{table}_table_id_tables",
referent_table="tables",
local_cols=["table_id"],
remote_cols=["id"],
ondelete=ondelete,
)

with op.batch_alter_table("sqlatable_user", naming_convention=conv) as batch_op:
for table, column in zip(("ab_user", "tables"), ("user_id", "table_id")):
if constraint := generic_find_fk_constraint_name(
table="sqlatable_user",
columns={"id"},
referenced=table,
insp=insp,
):
batch_op.drop_constraint(constraint, type_="foreignkey")

batch_op.create_foreign_key(
constraint_name=f"fk_sqlatable_user_{column}_{table}",
referent_table=table,
local_cols=[column],
remote_cols=["id"],
ondelete=ondelete,
)


def upgrade():
migrate(ondelete="CASCADE")


def downgrade():
migrate(ondelete=None)
21 changes: 20 additions & 1 deletion superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@
import re
import signal
import smtplib
import sqlite3
import ssl
import tempfile
import threading
import traceback
import uuid
import zlib
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from contextlib import closing, contextmanager
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from email.mime.application import MIMEApplication
Expand Down Expand Up @@ -849,6 +850,24 @@ def ping_connection(connection: Connection, branch: bool) -> None:
# restore 'close with result'
connection.should_close_with_result = save_should_close_with_result

if some_engine.dialect.name == "sqlite":

@event.listens_for(some_engine, "connect")
def set_sqlite_pragma( # pylint: disable=unused-argument
connection: sqlite3.Connection,
*args: Any,
) -> None:
r"""
Enable foreign key support for SQLite.
:param connection: The SQLite connection
:param \*args: Additional positional arguments
:see: https://docs.sqlalchemy.org/en/latest/dialects/sqlite.html
"""

with closing(connection.cursor()) as cursor:
cursor.execute("PRAGMA foreign_keys=ON")


def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many-locals
to: str,
Expand Down
2 changes: 0 additions & 2 deletions tests/integration_tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,6 @@ def test_import_chart(self):
assert chart.table == dataset

chart.owners = []
dataset.owners = []
db.session.delete(chart)
db.session.commit()
db.session.delete(dataset)
Expand Down Expand Up @@ -1577,7 +1576,6 @@ def test_import_chart_overwrite(self):
chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one()

chart.owners = []
dataset.owners = []
db.session.delete(chart)
db.session.commit()
db.session.delete(dataset)
Expand Down
1 change: 0 additions & 1 deletion tests/integration_tests/charts/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ def test_import_v1_chart(self, sm_g, utils_g):
assert chart.owners == [admin]

chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(chart)
db.session.delete(dataset)
Expand Down
3 changes: 1 addition & 2 deletions tests/integration_tests/commands_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def test_import_assets(self):

dashboard.owners = []
chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(dashboard)
db.session.delete(chart)
Expand All @@ -165,6 +164,7 @@ def test_import_v1_dashboard_overwrite(self):
"charts/imported_chart.yaml": yaml.safe_dump(chart_config),
"dashboards/imported_dashboard.yaml": yaml.safe_dump(dashboard_config),
}

command = ImportAssetsCommand(contents)
command.run()
chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one()
Expand Down Expand Up @@ -193,7 +193,6 @@ def test_import_v1_dashboard_overwrite(self):
dashboard.owners = []

chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(dashboard)
db.session.delete(chart)
Expand Down
1 change: 0 additions & 1 deletion tests/integration_tests/dashboards/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,6 @@ def test_import_v1_dashboard(self, sm_g, utils_g):

dashboard.owners = []
chart.owners = []
dataset.owners = []
database.owners = []
db.session.delete(dashboard)
db.session.delete(chart)
Expand Down
2 changes: 0 additions & 2 deletions tests/integration_tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,7 +2143,6 @@ def test_import_database(self):
assert dataset.table_name == "imported_dataset"
assert str(dataset.uuid) == dataset_config["uuid"]

dataset.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
Expand Down Expand Up @@ -2214,7 +2213,6 @@ def test_import_database_overwrite(self):
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
dataset = database.tables[0]
dataset.owners = []
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
Expand Down
Loading

0 comments on commit 75543af

Please sign in to comment.