diff --git a/pyproject.toml b/pyproject.toml
index 9fd740d82ece9..efb211c0d34fe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -239,8 +239,8 @@ basepython = python3.10
 ignore_basepython_conflict = true
 commands =
     superset db upgrade
-    superset load_test_users
     superset init
+    superset load-test-users
     # use -s to be able to use break pointers.
     # no args or tests/* can be passed as an argument to run all tests
     pytest -s {posargs}
diff --git a/scripts/permissions_cleanup.py b/scripts/permissions_cleanup.py
index c80ef231b3d22..22e58f013fa3e 100644
--- a/scripts/permissions_cleanup.py
+++ b/scripts/permissions_cleanup.py
@@ -14,51 +14,59 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=consider-using-transaction
 from collections import defaultdict

-from superset import db, security_manager
+from superset import security_manager
 from superset.utils.decorators import transaction


 @transaction()
 def cleanup_permissions() -> None:
     # 1. Clean up duplicates.
-    pvms = db.session.query(security_manager.permissionview_model).all()
+    pvms = security_manager.get_session.query(
+        security_manager.permissionview_model
+    ).all()
     print(f"# of permission view menus is: {len(pvms)}")
     pvms_dict = defaultdict(list)
     for pvm in pvms:
         pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
     duplicates = [v for v in pvms_dict.values() if len(v) > 1]
-    len(duplicates)

     for pvm_list in duplicates:
         first_prm = pvm_list[0]
         roles = set(first_prm.role)
         for pvm in pvm_list[1:]:
             roles = roles.union(pvm.role)
-            db.session.delete(pvm)
+            security_manager.get_session.delete(pvm)
         first_prm.roles = list(roles)

-    pvms = db.session.query(security_manager.permissionview_model).all()
+    pvms = security_manager.get_session.query(
+        security_manager.permissionview_model
+    ).all()
     print(f"Stage 1: # of permission view menus is: {len(pvms)}")

     # 2. Clean up None permissions or view menus
-    pvms = db.session.query(security_manager.permissionview_model).all()
+    pvms = security_manager.get_session.query(
+        security_manager.permissionview_model
+    ).all()
     for pvm in pvms:
         if not (pvm.view_menu and pvm.permission):
-            db.session.delete(pvm)
+            security_manager.get_session.delete(pvm)

-    pvms = db.session.query(security_manager.permissionview_model).all()
+    pvms = security_manager.get_session.query(
+        security_manager.permissionview_model
+    ).all()
     print(f"Stage 2: # of permission view menus is: {len(pvms)}")

     # 3. Delete empty permission view menus from roles
-    roles = db.session.query(security_manager.role_model).all()
+    roles = security_manager.get_session.query(security_manager.role_model).all()
     for role in roles:
         role.permissions = [p for p in role.permissions if p]

     # 4. Delete empty roles from permission view menus
-    pvms = db.session.query(security_manager.permissionview_model).all()
+    pvms = security_manager.get_session.query(
+        security_manager.permissionview_model
+    ).all()
     for pvm in pvms:
         pvm.role = [r for r in pvm.role if r]
diff --git a/scripts/python_tests.sh b/scripts/python_tests.sh
index 443b1d5d61ddc..63ea35599c93f 100755
--- a/scripts/python_tests.sh
+++ b/scripts/python_tests.sh
@@ -28,9 +28,10 @@ export SUPERSET_TESTENV=true
 echo "Superset config module: $SUPERSET_CONFIG"

 superset db upgrade
-superset load_test_users
 superset init
+superset load-test-users

 echo "Running tests"
+# pytest --durations-min=2 --maxfail=1 --cov-report= --cov=superset ./tests/integration_tests "$@"
 pytest --durations-min=2 --cov-report= --cov=superset ./tests/integration_tests "$@"
diff --git a/superset/cli/test.py b/superset/cli/test.py
index 33b777b1eff06..60ea532cbdba4 100755
--- a/superset/cli/test.py
+++ b/superset/cli/test.py
@@ -37,15 +37,7 @@ def load_test_users() -> None:
     Syncs permissions for those users/roles
     """
     print(Fore.GREEN + "Loading a set of users for unit tests")
-    load_test_users_run()
-
-def load_test_users_run() -> None:
-    """
-    Loads admin, alpha, and gamma user for testing purposes
-
-    Syncs permissions for those users/roles
-    """
     if app.config["TESTING"]:
         sm = security_manager
diff --git a/superset/commands/chart/importers/v1/utils.py b/superset/commands/chart/importers/v1/utils.py
index 39ca49a5d5ffc..35a7f6e2700f3 100644
--- a/superset/commands/chart/importers/v1/utils.py
+++ b/superset/commands/chart/importers/v1/utils.py
@@ -77,7 +77,7 @@ def import_chart(
     if chart.id is None:
         db.session.flush()

-    if user := get_user():
+    if (user := get_user()) and user not in chart.owners:
         chart.owners.append(user)

     return chart
diff --git a/superset/commands/chart/update.py b/superset/commands/chart/update.py
index 1ea698ba0dc48..d6b212d5ce861 100644
--- a/superset/commands/chart/update.py
+++ b/superset/commands/chart/update.py
@@ -38,8 +38,8 @@
 from superset.daos.dashboard import DashboardDAO
 from superset.exceptions import SupersetSecurityException
 from superset.models.slice import Slice
-from superset.utils.decorators import on_error, transaction
 from superset.tags.models import ObjectType
+from superset.utils.decorators import on_error, transaction

 logger = logging.getLogger(__name__)

@@ -62,14 +62,13 @@ def run(self) -> Model:
         assert self._model

         # Update tags
-        tags = self._properties.pop("tags", None)
-        if tags is not None:
+        if (tags := self._properties.pop("tags", None)) is not None:
             update_tags(ObjectType.chart, self._model.id, self._model.tags, tags)

         if self._properties.get("query_context_generation") is None:
             self._properties["last_saved_at"] = datetime.now()
             self._properties["last_saved_by"] = g.user
-
+
         return ChartDAO.update(self._model, self._properties)

     def validate(self) -> None:
diff --git a/superset/commands/dashboard/importers/v1/utils.py b/superset/commands/dashboard/importers/v1/utils.py
index f10afd12bc9ee..5e949093b8a80 100644
--- a/superset/commands/dashboard/importers/v1/utils.py
+++ b/superset/commands/dashboard/importers/v1/utils.py
@@ -188,7 +188,7 @@ def import_dashboard(
     if dashboard.id is None:
         db.session.flush()

-    if user := get_user():
+    if (user := get_user()) and user not in dashboard.owners:
         dashboard.owners.append(user)

     return dashboard
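The chart and dashboard importers above now guard against appending a duplicate owner on re-import. A minimal, self-contained sketch of the walrus-plus-membership pattern; the `Chart` dataclass and `get_user` stub below are stand-ins for illustration, not Superset's real classes:

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Chart:
    owners: list = field(default_factory=list)


def get_user() -> Optional[str]:
    # stand-in for superset.utils.core.get_user(); may return None
    return "admin"


chart = Chart()
for _ in range(2):  # simulate importing the same asset twice
    # the walrus operator binds and tests in one expression; the extra
    # membership check prevents appending the same owner on re-import
    if (user := get_user()) and user not in chart.owners:
        chart.owners.append(user)

assert chart.owners == ["admin"]
```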
diff --git a/superset/commands/dashboard/permalink/create.py b/superset/commands/dashboard/permalink/create.py
index f6bff344c8076..7d08f78e9a9be 100644
--- a/superset/commands/dashboard/permalink/create.py
+++ b/superset/commands/dashboard/permalink/create.py
@@ -19,7 +19,6 @@
 from sqlalchemy.exc import SQLAlchemyError

-from superset import db
 from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand
 from superset.commands.key_value.upsert import UpsertKeyValueCommand
 from superset.daos.dashboard import DashboardDAO
@@ -78,6 +77,7 @@ def run(self) -> str:
             codec=self.codec,
         ).run()
         assert key.id  # for type checks
+
         return encode_permalink_key(key=key.id, salt=self.salt)

     def validate(self) -> None:
         pass
diff --git a/superset/commands/dashboard/update.py b/superset/commands/dashboard/update.py
index 5294d049ec714..2effd7bd2ece1 100644
--- a/superset/commands/dashboard/update.py
+++ b/superset/commands/dashboard/update.py
@@ -53,19 +53,16 @@ def run(self) -> Model:
         assert self._model

         # Update tags
-        tags = self._properties.pop("tags", None)
-        if tags is not None:
-            update_tags(
-                ObjectType.dashboard, self._model.id, self._model.tags, tags
-            )
+        if (tags := self._properties.pop("tags", None)) is not None:
+            update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags)

-        dashboard = DashboardDAO.update(self._model, self._properties, commit=False)
+        dashboard = DashboardDAO.update(self._model, self._properties)
         if self._properties.get("json_metadata"):
             DashboardDAO.set_dash_metadata(
                 dashboard,
                 data=json.loads(self._properties.get("json_metadata", "{}")),
             )
-
+
         return dashboard

     def validate(self) -> None:
diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py
index 842f69a7ab0e8..76dd6087be58a 100644
--- a/superset/commands/database/create.py
+++ b/superset/commands/database/create.py
@@ -40,7 +40,7 @@
 )
 from superset.commands.database.test_connection import TestConnectionDatabaseCommand
 from superset.daos.database import DatabaseDAO
-from superset.daos.exceptions import DAOCreateFailedError
+from superset.databases.ssh_tunnel.models import SSHTunnel
 from superset.exceptions import SupersetErrorsException
 from superset.extensions import event_logger, security_manager
 from superset.models.core import Database
diff --git a/superset/commands/database/csv_import.py b/superset/commands/database/csv_import.py
deleted file mode 100644
index 3354a81a4d806..0000000000000
--- a/superset/commands/database/csv_import.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from functools import partial
-from typing import Any, Optional, TypedDict
-
-import pandas as pd
-from flask_babel import lazy_gettext as _
-
-from superset import db
-from superset.commands.base import BaseCommand
-from superset.commands.database.exceptions import (
-    DatabaseNotFoundError,
-    DatabaseSchemaUploadNotAllowed,
-    DatabaseUploadFailed,
-    DatabaseUploadSaveMetadataFailed,
-)
-from superset.connectors.sqla.models import SqlaTable
-from superset.daos.database import DatabaseDAO
-from superset.models.core import Database
-from superset.sql_parse import Table
-from superset.utils.core import get_user
-from superset.utils.decorators import on_error, transaction
-from superset.views.database.validators import schema_allows_file_upload
-
-logger = logging.getLogger(__name__)
-
-READ_CSV_CHUNK_SIZE = 1000
-
-
-class CSVImportOptions(TypedDict, total=False):
-    schema: str
-    delimiter: str
-    already_exists: str
-    column_data_types: dict[str, str]
-    column_dates: list[str]
-    column_labels: str
-    columns_read: list[str]
-    dataframe_index: str
-    day_first: bool
-    decimal_character: str
-    header_row: int
-    index_column: str
-    null_values: list[str]
-    overwrite_duplicates: bool
-    rows_to_read: int
-    skip_blank_lines: bool
-    skip_initial_space: bool
-    skip_rows: int
-
-
-class CSVImportCommand(BaseCommand):
-    def __init__(
-        self,
-        model_id: int,
-        table_name: str,
-        file: Any,
-        options: CSVImportOptions,
-    ) -> None:
-        self._model_id = model_id
-        self._model: Optional[Database] = None
-        self._table_name = table_name
-        self._schema = options.get("schema")
-        self._file = file
-        self._options = options
-
-    def _read_csv(self) -> pd.DataFrame:
-        """
-        Read CSV file into a DataFrame
-
-        :return: pandas DataFrame
-        :throws DatabaseUploadFailed: if there is an error reading the CSV file
-        """
-        try:
-            return pd.concat(
-                pd.read_csv(
-                    chunksize=READ_CSV_CHUNK_SIZE,
-                    encoding="utf-8",
-                    filepath_or_buffer=self._file,
-                    header=self._options.get("header_row", 0),
-                    index_col=self._options.get("index_column"),
-                    dayfirst=self._options.get("day_first", False),
-                    iterator=True,
-                    keep_default_na=not self._options.get("null_values"),
-                    usecols=self._options.get("columns_read")
-                    if self._options.get("columns_read")  # None if an empty list
-                    else None,
-                    na_values=self._options.get("null_values")
-                    if self._options.get("null_values")  # None if an empty list
-                    else None,
-                    nrows=self._options.get("rows_to_read"),
-                    parse_dates=self._options.get("column_dates"),
-                    sep=self._options.get("delimiter", ","),
-                    skip_blank_lines=self._options.get("skip_blank_lines", False),
-                    skipinitialspace=self._options.get("skip_initial_space", False),
-                    skiprows=self._options.get("skip_rows", 0),
-                    dtype=self._options.get("column_data_types")
-                    if self._options.get("column_data_types")
-                    else None,
-                )
-            )
-        except (
-            pd.errors.ParserError,
-            pd.errors.EmptyDataError,
-            UnicodeDecodeError,
-            ValueError,
-        ) as ex:
-            raise DatabaseUploadFailed(
-                message=_("Parsing error: %(error)s", error=str(ex))
-            ) from ex
-        except Exception as ex:
-            raise DatabaseUploadFailed(_("Error reading CSV file")) from ex
-
-    def _dataframe_to_database(self, df: pd.DataFrame, database: Database) -> None:
-        """
-        Upload DataFrame to database
-
-        :param df:
-        :throws DatabaseUploadFailed: if there is an error uploading the DataFrame
-        """
-        try:
-            csv_table = Table(table=self._table_name, schema=self._schema)
-            database.db_engine_spec.df_to_sql(
-                database,
-                csv_table,
-                df,
-                to_sql_kwargs={
-                    "chunksize": READ_CSV_CHUNK_SIZE,
-                    "if_exists": self._options.get("already_exists", "fail"),
-                    "index": self._options.get("index_column"),
-                    "index_label": self._options.get("column_labels"),
-                },
-            )
-        except ValueError as ex:
-            raise DatabaseUploadFailed(
-                message=_(
-                    "Table already exists. You can change your "
-                    "'if table already exists' strategy to append or "
-                    "replace or provide a different Table Name to use."
-                )
-            ) from ex
-        except Exception as ex:
-            raise DatabaseUploadFailed(exception=ex) from ex
-
-    @transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed))
-    def run(self) -> None:
-        self.validate()
-        if not self._model:
-            return
-
-        df = self._read_csv()
-        self._dataframe_to_database(df, self._model)
-
-        sqla_table = (
-            db.session.query(SqlaTable)
-            .filter_by(
-                table_name=self._table_name,
-                schema=self._schema,
-                database_id=self._model_id,
-            )
-            .one_or_none()
-        )
-        if not sqla_table:
-            sqla_table = SqlaTable(
-                table_name=self._table_name,
-                database=self._model,
-                database_id=self._model_id,
-                owners=[get_user()],
-                schema=self._schema,
-            )
-            db.session.add(sqla_table)
-
-        sqla_table.fetch_metadata()
-
-    def validate(self) -> None:
-        self._model = DatabaseDAO.find_by_id(self._model_id)
-        if not self._model:
-            raise DatabaseNotFoundError()
-        if not schema_allows_file_upload(self._model, self._schema):
-            raise DatabaseSchemaUploadNotAllowed()
diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py
index 1fd7d786dcf93..b4db4162007a2 100644
--- a/superset/commands/database/update.py
+++ b/superset/commands/database/update.py
@@ -26,6 +26,7 @@
 from superset import is_feature_enabled, security_manager
 from superset.commands.base import BaseCommand
 from superset.commands.database.exceptions import (
+    DatabaseConnectionFailedError,
     DatabaseExistsValidationError,
     DatabaseInvalidError,
     DatabaseNotFoundError,
@@ -41,7 +42,6 @@
 from superset.daos.database import DatabaseDAO
 from superset.daos.dataset import DatasetDAO
 from superset.databases.ssh_tunnel.models import SSHTunnel
-from superset.extensions import db
 from superset.models.core import Database
 from superset.utils.decorators import on_error, transaction

@@ -78,16 +78,12 @@ def run(self) -> Model:
         original_database_name = self._model.database_name

         try:
-            database = DatabaseDAO.update(
-                self._model,
-                self._properties,
-                commit=False,
-            )
+            database = DatabaseDAO.update(self._model, self._properties)
             database.set_sqlalchemy_uri(database.sqlalchemy_uri)
             ssh_tunnel = self._handle_ssh_tunnel(database)
             self._refresh_catalogs(database, original_database_name, ssh_tunnel)
         except SSHTunnelError:  # pylint: disable=try-except-raise
-            # allow exception to bubble for debugbing information
+            # allow exception to bubble for debugging information
             raise

         return database
@@ -100,7 +96,6 @@ def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None:
             return None

         if not is_feature_enabled("SSH_TUNNELING"):
-            db.session.rollback()
             raise SSHTunnelingNotEnabledError()

         current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id)
@@ -130,13 +125,13 @@ def _get_catalog_names(
         This method captures a generic exception, since errors could
         potentially come from any of the 50+ database drivers we support.
         """
+
         try:
             return database.get_all_catalog_names(
                 force=True,
                 ssh_tunnel=ssh_tunnel,
             )
         except Exception as ex:
-            db.session.rollback()
             raise DatabaseConnectionFailedError() from ex

     def _get_schema_names(
@@ -151,6 +146,7 @@ def _get_schema_names(
         This method captures a generic exception, since errors could
         potentially come from any of the 50+ database drivers we support.
         """
+
         try:
             return database.get_all_schema_names(
                 force=True,
@@ -158,7 +154,6 @@ def _get_schema_names(
                 ssh_tunnel=ssh_tunnel,
             )
         except Exception as ex:
-            db.session.rollback()
             raise DatabaseConnectionFailedError() from ex

     def _refresh_catalogs(
@@ -224,8 +219,6 @@ def _refresh_catalogs(
                 schemas,
             )

-        db.session.commit()
-
     def _refresh_schemas(
         self,
         database: Database,
diff --git a/superset/commands/dataset/columns/delete.py b/superset/commands/dataset/columns/delete.py
index 1fb2863b1bef4..821528de74d4f 100644
--- a/superset/commands/dataset/columns/delete.py
+++ b/superset/commands/dataset/columns/delete.py
@@ -43,7 +43,7 @@ def __init__(self, dataset_id: int, model_id: int):
     def run(self) -> None:
         self.validate()
         assert self._model
-        DatasetColumnDAO.delete(self._model)
+        DatasetColumnDAO.delete([self._model])

     def validate(self) -> None:
         # Validate/populate model exists
diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py
index 28983e74f8693..a2d81e548bfb0 100644
--- a/superset/commands/dataset/create.py
+++ b/superset/commands/dataset/create.py
@@ -32,7 +32,7 @@
 )
 from superset.daos.dataset import DatasetDAO
 from superset.exceptions import SupersetSecurityException
-from superset.extensions import db, security_manager
+from superset.extensions import security_manager
 from superset.sql_parse import Table
 from superset.utils.decorators import on_error, transaction
diff --git a/superset/commands/dataset/importers/v0.py b/superset/commands/dataset/importers/v0.py
index 6b6f19374233f..d6f7380cb5d1d 100644
--- a/superset/commands/dataset/importers/v0.py
+++ b/superset/commands/dataset/importers/v0.py
@@ -260,7 +260,6 @@ def run(self) -> None:
             )
             dataset["database_id"] = database.id
             SqlaTable.import_from_dict(dataset, sync=self.sync)
-        db.session.commit()

     def validate(self) -> None:
         # ensure all files are YAML
diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py
index da39be4721c0c..1c508fe2522e8 100644
--- a/superset/commands/dataset/importers/v1/utils.py
+++ b/superset/commands/dataset/importers/v1/utils.py
@@ -178,7 +178,7 @@ def import_dataset(
     if data_uri and (not table_exists or force_data):
         load_data(data_uri, dataset, dataset.database)

-    if user := get_user():
+    if (user := get_user()) and user not in dataset.owners:
         dataset.owners.append(user)

     return dataset
diff --git a/superset/commands/dataset/metrics/delete.py b/superset/commands/dataset/metrics/delete.py
index e4d65236c3017..0a749295dc3d6 100644
--- a/superset/commands/dataset/metrics/delete.py
+++ b/superset/commands/dataset/metrics/delete.py
@@ -43,7 +43,7 @@ def __init__(self, dataset_id: int, model_id: int):
     def run(self) -> None:
         self.validate()
         assert self._model
-        DatasetMetricDAO.delete(self._model)
+        DatasetMetricDAO.delete([self._model])

     def validate(self) -> None:
         # Validate/populate model exists
diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py
index 40b1bf18baf6e..14d1c5ef44707 100644
--- a/superset/commands/dataset/update.py
+++ b/superset/commands/dataset/update.py
@@ -21,6 +21,7 @@
 from flask_appbuilder.models.sqla import Model
 from marshmallow import ValidationError
+from sqlalchemy.exc import SQLAlchemyError

 from superset import security_manager
 from superset.commands.base import BaseCommand, UpdateMixin
@@ -60,7 +61,16 @@ def __init__(
         self.override_columns = override_columns
         self._properties["override_columns"] = override_columns

-    @transaction(on_error=partial(on_error, reraise=DatasetUpdateFailedError))
+    @transaction(
+        on_error=partial(
+            on_error,
+            catches=(
+                SQLAlchemyError,
+                ValueError,
+            ),
+            reraise=DatasetUpdateFailedError,
+        )
+    )
     def run(self) -> Model:
         self.validate()
         assert self._model
diff --git a/superset/commands/explore/permalink/create.py b/superset/commands/explore/permalink/create.py
index 03efdc584a424..2128fa4b8c40e 100644
--- a/superset/commands/explore/permalink/create.py
+++ b/superset/commands/explore/permalink/create.py
@@ -20,7 +20,6 @@
 from sqlalchemy.exc import SQLAlchemyError

-from superset import db
 from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand
 from superset.commands.key_value.create import CreateKeyValueCommand
 from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError
@@ -74,27 +73,7 @@ def run(self) -> str:
         key = command.run()
         if key.id is None:
             raise ExplorePermalinkCreateFailedError("Unexpected missing key id")
-        db.session.commit()
         return encode_permalink_key(key=key.id, salt=self.salt)
-        d_id, d_type = self.datasource.split("__")
-        datasource_id = int(d_id)
-        datasource_type = DatasourceType(d_type)
-        check_chart_access(datasource_id, self.chart_id, datasource_type)
-        value = {
-            "chartId": self.chart_id,
-            "datasourceId": datasource_id,
-            "datasourceType": datasource_type.value,
-            "datasource": self.datasource,
-            "state": self.state,
-        }
-        command = CreateKeyValueCommand(
-            resource=self.resource,
-            value=value,
-            codec=self.codec,
-        )
-        key = command.run()
-        return encode_permalink_key(key=key.id, salt=self.salt)
->>>>>>> c01dacb71a (chore(dao): Use nested session for operations)

     def validate(self) -> None:
         pass
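`UpdateDatasetCommand` above narrows which exceptions get translated into `DatasetUpdateFailedError`. The following standalone sketch shows how `partial(on_error, catches=..., reraise=...)` composes; this `on_error` is a simplified stand-in for the real handler in `superset/utils/decorators.py`:

```python
from functools import partial


class DatasetUpdateFailedError(Exception):
    pass


def on_error(ex: Exception, catches=(Exception,), reraise=None):
    # simplified stand-in: translate matching exceptions into the
    # domain-specific error, otherwise let the original propagate
    if isinstance(ex, catches) and reraise:
        raise reraise() from ex
    raise ex


handler = partial(
    on_error,
    catches=(ValueError,),
    reraise=DatasetUpdateFailedError,
)

try:
    handler(ValueError("bad column"))
except DatasetUpdateFailedError as caught:
    print(f"translated: {caught.__cause__!r}")
```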
diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py
index 5f955db3bf9ee..78a2251a293af 100644
--- a/superset/commands/importers/v1/assets.py
+++ b/superset/commands/importers/v1/assets.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from functools import partial
 from typing import Any, Optional

 from marshmallow import Schema
@@ -44,7 +45,7 @@
 from superset.migrations.shared.native_filters import migrate_dashboard
 from superset.models.dashboard import dashboard_slices
 from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
-from superset.utils.decorators import transaction
+from superset.utils.decorators import on_error, transaction


 class ImportAssetsCommand(BaseCommand):
@@ -154,14 +155,16 @@ def _import(configs: dict[str, Any]) -> None:
                 if chart.viz_type == "filter_box":
                     db.session.delete(chart)

-    @transaction()
+    @transaction(
+        on_error=partial(
+            on_error,
+            catches=(Exception,),
+            reraise=ImportFailedError,
+        )
+    )
     def run(self) -> None:
         self.validate()
-
-        try:
-            self._import(self._configs)
-        except Exception as ex:
-            raise ImportFailedError() from ex
+        self._import(self._configs)

     def validate(self) -> None:
         exceptions: list[ValidationError] = []
diff --git a/superset/commands/key_value/delete.py b/superset/commands/key_value/delete.py
index ec386675b5037..a3fdf079c73c2 100644
--- a/superset/commands/key_value/delete.py
+++ b/superset/commands/key_value/delete.py
@@ -53,9 +53,11 @@ def validate(self) -> None:
         pass

     def delete(self) -> bool:
-        filter_ = get_filter(self.resource, self.key)
-        if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first():
+        if (
+            entry := db.session.query(KeyValueEntry)
+            .filter_by(**get_filter(self.resource, self.key))
+            .first()
+        ):
             db.session.delete(entry)
-            db.session.flush()
             return True
         return False
diff --git a/superset/commands/key_value/delete_expired.py b/superset/commands/key_value/delete_expired.py
index e3e75bf0f1fa4..54991c7531d27 100644
--- a/superset/commands/key_value/delete_expired.py
+++ b/superset/commands/key_value/delete_expired.py
@@ -60,4 +60,3 @@ def delete_expired(self) -> None:
             )
             .delete()
         )
-        db.session.flush()
diff --git a/superset/commands/key_value/upsert.py b/superset/commands/key_value/upsert.py
index e5c6eb74258c5..32918d9b14396 100644
--- a/superset/commands/key_value/upsert.py
+++ b/superset/commands/key_value/upsert.py
@@ -21,6 +21,8 @@
 from typing import Any, Optional, Union
 from uuid import UUID

+from sqlalchemy.exc import SQLAlchemyError
+
 from superset import db
 from superset.commands.base import BaseCommand
 from superset.commands.key_value.create import CreateKeyValueCommand
@@ -71,7 +73,7 @@ def __init__(  # pylint: disable=too-many-arguments
     @transaction(
         on_error=partial(
             on_error,
-            catches=(KeyValueCreateFailedError,),
+            catches=(KeyValueCreateFailedError, SQLAlchemyError),
             reraise=KeyValueUpsertFailedError,
         ),
     )
@@ -82,16 +84,15 @@ def validate(self) -> None:
         pass

     def upsert(self) -> Key:
-        filter_ = get_filter(self.resource, self.key)
-        entry: KeyValueEntry = (
-            db.session.query(KeyValueEntry).filter_by(**filter_).first()
-        )
-        if entry:
+        if (
+            entry := db.session.query(KeyValueEntry)
+            .filter_by(**get_filter(self.resource, self.key))
+            .first()
+        ):
             entry.value = self.codec.encode(self.value)
             entry.expires_on = self.expires_on
             entry.changed_on = datetime.now()
             entry.changed_by_fk = get_user_id()
-            db.session.flush()
             return Key(entry.id, entry.uuid)

         return CreateKeyValueCommand(
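The upsert now folds the lookup into a walrus expression: update in place when the row exists, otherwise fall through to the create path. A dictionary-backed sketch of the same select-then-update-or-create shape, illustrative only:

```python
store: dict[str, str] = {"existing": "old"}


def upsert(key: str, value: str) -> str:
    # mirrors the command: mutate the existing entry when the lookup
    # succeeds, otherwise fall through to the "create" branch
    if (entry := store.get(key)) is not None:
        store[key] = value
        return f"updated {key} (was {entry})"
    store[key] = value
    return f"created {key}"


print(upsert("existing", "new"))  # updated existing (was old)
print(upsert("fresh", "value"))   # created fresh
```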
diff --git a/superset/commands/report/execute.py b/superset/commands/report/execute.py
index 000c87f514a1d..c57828eac497b 100644
--- a/superset/commands/report/execute.py
+++ b/superset/commands/report/execute.py
@@ -137,6 +137,7 @@ def create_log(self, error_message: Optional[str] = None) -> None:
             uuid=self._execution_id,
         )
         db.session.add(log)
+        db.session.commit()  # pylint: disable=consider-using-transaction

     def _get_url(
         self,
diff --git a/superset/commands/report/log_prune.py b/superset/commands/report/log_prune.py
index 493c16ed77b80..a780bf51e0333 100644
--- a/superset/commands/report/log_prune.py
+++ b/superset/commands/report/log_prune.py
@@ -49,7 +49,6 @@ def run(self) -> None:
                 report_schedule,
                 from_date,
             )
-            db.session.commit()
             logger.info(
                 "Deleted %s logs for report schedule id: %s",
                 str(row_count),
diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py
index 8392b7c139224..cf7af4b2f2785 100644
--- a/superset/commands/sql_lab/execute.py
+++ b/superset/commands/sql_lab/execute.py
@@ -144,7 +144,6 @@ def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
         self._execution_context.set_database(self._get_the_query_db())
         query = self._execution_context.create_query()
         self._save_new_query(query)
-        db.session.flush()

         try:
             logger.info("Triggering query_id: %i", query.id)
@@ -182,6 +181,17 @@ def _validate_query_db(cls, database: Database | None) -> None:
             )

     def _save_new_query(self, query: Query) -> None:
+        """
+        Saves the new SQL Lab query.
+
+        Committing within a transaction violates the "unit of work" construct, but is
+        necessary for async querying given how the command is currently defined.
+
+        To mitigate said issue, ideally there would be a command to prepare said query
+        and another to execute it, either in a sync or async manner.
+
+        :param query: The SQL Lab query
+        """
         try:
             self._query_dao.create(query)
         except SQLAlchemyError as ex:
@@ -193,6 +203,8 @@ def _save_new_query(self, query: Query) -> None:
                 "Please contact an administrator for further assistance or try again.",
             ) from ex

+        db.session.commit()  # pylint: disable=consider-using-transaction
+
     def _validate_access(self, query: Query) -> None:
         try:
             self._access_validator.validate(query)
diff --git a/superset/daos/tag.py b/superset/daos/tag.py
index 98c83dbe8e7f7..b155cf15c1522 100644
--- a/superset/daos/tag.py
+++ b/superset/daos/tag.py
@@ -383,5 +383,4 @@ def create_tag_relationship(
                 object_id,
                 tag.name,
             )
-
     db.session.add_all(tagged_objects)
diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py
index 3fe557a6843bc..823bfdfa8cc8b 100644
--- a/superset/dashboards/api.py
+++ b/superset/dashboards/api.py
@@ -32,7 +32,7 @@
 from werkzeug.wrappers import Response as WerkzeugResponse
 from werkzeug.wsgi import FileWrapper

-from superset import is_feature_enabled, thumbnail_cache
+from superset import db, is_feature_enabled, thumbnail_cache
 from superset.charts.schemas import ChartEntityResponseSchema
 from superset.commands.dashboard.create import CreateDashboardCommand
 from superset.commands.dashboard.delete import DeleteDashboardCommand
@@ -1314,7 +1314,13 @@ def set_embedded(self, dashboard: Dashboard) -> Response:
         """
         try:
             body = self.embedded_config_schema.load(request.json)
-            embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"])
+
+            with db.session.begin_nested():
+                embedded = EmbeddedDashboardDAO.upsert(
+                    dashboard,
+                    body["allowed_domains"],
+                )
+
             result = self.embedded_response_schema.dump(embedded)
             return self.response(200, result=result)
         except ValidationError as error:
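`set_embedded` above wraps the DAO call in `db.session.begin_nested()`, i.e. a SAVEPOINT whose work can be rolled back without aborting the enclosing transaction. A runnable SQLAlchemy sketch of those semantics, assuming an in-memory SQLite database and a toy model (not Superset's real `EmbeddedDashboard`):

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Embedded(Base):
    __tablename__ = "embedded"
    id = Column(Integer, primary_key=True)
    domain = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

# begin_nested() opens a SAVEPOINT: work inside the block is flushed on
# exit, but the outer transaction stays open and still decides the fate
with session.begin_nested():
    session.add(Embedded(domain="test.example.com"))

session.commit()
print(session.query(Embedded).count())  # 1
```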
diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py
index 7b4e39677e48f..1c89e8459774d 100644
--- a/superset/extensions/metastore_cache.py
+++ b/superset/extensions/metastore_cache.py
@@ -22,7 +22,6 @@
 from flask import current_app, Flask, has_app_context
 from flask_caching import BaseCache

-from superset import db
 from superset.key_value.exceptions import KeyValueCreateFailedError
 from superset.key_value.types import (
     KeyValueCodec,
@@ -95,7 +94,6 @@ def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
             codec=self.codec,
             expires_on=self._get_expiry(timeout),
         ).run()
-        db.session.commit()
         return True

     def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
@@ -111,7 +109,6 @@ def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool:
                 key=self.get_key(key),
                 expires_on=self._get_expiry(timeout),
             ).run()
-            db.session.commit()
             return True
         except KeyValueCreateFailedError:
             return False
@@ -136,6 +133,4 @@ def delete(self, key: str) -> Any:
         # pylint: disable=import-outside-toplevel
         from superset.commands.key_value.delete import DeleteKeyValueCommand

-        ret = DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
-        db.session.commit()
-        return ret
+        return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run()
diff --git a/superset/key_value/shared_entries.py b/superset/key_value/shared_entries.py
index f472838d2e090..130313157a53d 100644
--- a/superset/key_value/shared_entries.py
+++ b/superset/key_value/shared_entries.py
@@ -18,7 +18,6 @@
 from typing import Any, Optional
 from uuid import uuid3

-from superset import db
 from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey
 from superset.key_value.utils import get_uuid_namespace, random_key

@@ -46,7 +45,6 @@ def set_shared_value(key: SharedKey, value: Any) -> None:
         key=uuid_key,
         codec=CODEC,
     ).run()
-    db.session.commit()


 def get_permalink_salt(key: SharedKey) -> str:
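The metastore cache can drop its trailing commits because the key/value commands now finish their own unit of work. A toy illustration of that calling-convention change; the `FakeSession` and decorator below are stand-ins for illustration, not Superset code:

```python
class FakeSession:
    def __init__(self) -> None:
        self.committed = False

    def commit(self) -> None:
        self.committed = True


session = FakeSession()


def transaction(func):
    # toy stand-in for superset.utils.decorators.transaction: the command
    # itself closes the unit of work
    def wrapped(*args, **kwargs):
        result = func(*args, **kwargs)
        session.commit()
        return result

    return wrapped


class DeleteKeyValueCommand:
    @transaction
    def run(self) -> bool:
        return True


# call sites reduce to a single expression; no trailing db.session.commit()
assert DeleteKeyValueCommand().run() is True
assert session.committed
```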
diff --git a/superset/security/manager.py b/superset/security/manager.py
index a76400efca4de..b4bc0c6103def 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -48,7 +48,7 @@
 from jwt.api_jwt import _jwt_global_obj
 from sqlalchemy import and_, inspect, or_
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import eagerload, Session
+from sqlalchemy.orm import eagerload
 from sqlalchemy.orm.mapper import Mapper
 from sqlalchemy.orm.query import Query as SqlaQuery
@@ -333,24 +333,6 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
     guest_user_cls = GuestUser
     pyjwt_for_guest_token = _jwt_global_obj

-    @property
-    def get_session(self) -> Session:
-        """
-        Flask-AppBuilder (FAB) which has a tendency to explicitly commit, thus violating
-        our definition of "unit of work".
-
-        By providing a monkey patched transaction for the FAB session ensures that any
-        explicit commit merely flushes and any rollback is a no-op.
-        """
-
-        # pylint: disable=import-outside-toplevel
-        from superset import db
-
-        with db.session.begin_nested() as transaction:
-            transaction.session.commit = transaction.session.flush
-            transaction.session.rollback = lambda: None
-            return transaction.session
-
     def create_login_manager(self, app: Flask) -> LoginManager:
         lm = super().create_login_manager(app)
         lm.request_loader(self.request_loader)
@@ -1062,7 +1044,6 @@ def sync_role_definitions(self) -> None:
             self.auth_role_public,
             merge=True,
         )
-
         self.create_missing_perms()

         self.clean_perms()
@@ -1136,7 +1117,6 @@ def copy_role(
             ):
                 role_from_permissions.append(permission_view)
         role_to.permissions = role_from_permissions
-        self.get_session.flush()

     def set_role(
         self,
@@ -1157,7 +1137,6 @@ def set_role(
             permission_view for permission_view in pvms if pvm_check(permission_view)
         ]
         role.permissions = role_pvms
-        self.get_session.flush()

     def _is_admin_only(self, pvm: PermissionView) -> bool:
         """
@@ -2458,9 +2437,6 @@ def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
             RowLevelSecurityFilter,
         )

-        print(">>> get_rls_filters <<<")
-        print(g.user)
-        print(self.get_user_roles(g.user))
         user_roles = [role.id for role in self.get_user_roles(g.user)]
         regular_filter_roles = (
             self.get_session.query(RLSFilterRoles.c.rls_filter_id)
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index d4af792646fde..9712ab47ab426 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -128,6 +128,7 @@ def handle_query_error(


 def get_query_backoff_handler(details: dict[Any, Any]) -> None:
+    print(details)
     query_id = details["kwargs"]["query_id"]
     logger.error(
         "Query with id `%s` could not be retrieved", str(query_id), exc_info=True
diff --git a/superset/sqllab/sql_json_executer.py b/superset/sqllab/sql_json_executer.py
index fde73aef0a86e..ac9968ed6b467 100644
--- a/superset/sqllab/sql_json_executer.py
+++ b/superset/sqllab/sql_json_executer.py
@@ -90,6 +90,7 @@ def execute(
         rendered_query: str,
         log_params: dict[str, Any] | None,
     ) -> SqlJsonExecutionStatus:
+        print(">>> execute <<<")
         query_id = execution_context.query.id
         try:
             data = self._get_sql_results_with_timeout(
@@ -101,6 +102,7 @@ def execute(
             raise
         except Exception as ex:
             logger.exception("Query %i failed unexpectedly", query_id)
+            print(str(ex))
             raise SupersetGenericDBErrorException(
                 utils.error_msg_from_exception(ex)
             ) from ex
@@ -112,6 +114,7 @@ def execute(
                     [SupersetError(**params) for params in data["errors"]]  # type: ignore
                 )
             # old string-only error message
+            print(data)
             raise SupersetGenericDBErrorException(data["error"])  # type: ignore
         return SqlJsonExecutionStatus.HAS_RESULTS
diff --git a/superset/tags/models.py b/superset/tags/models.py
index 8c3e53b31488a..31975c3e8e882 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -140,6 +140,7 @@ def get_tag(
     if tag is None:
         tag = Tag(name=escape(tag_name), type=type_)
         session.add(tag)
+        session.commit()

     return tag
diff --git a/superset/utils/database.py b/superset/utils/database.py
index 7ed315650226e..719e7f2d772c7 100644
--- a/superset/utils/database.py
+++ b/superset/utils/database.py
@@ -59,6 +59,7 @@ def get_or_create_db(

     if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri:
         database.set_sqlalchemy_uri(sqlalchemy_uri)
+        db.session.flush()

     return database

@@ -78,3 +79,4 @@ def remove_database(database: Database) -> None:
     from superset import db

     db.session.delete(database)
+    db.session.flush()
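`get_or_create_db` and `remove_database` now flush rather than commit: flush emits the pending SQL (so generated keys become available) while leaving the transaction open for the caller to commit or roll back. A runnable sketch with a toy model, assuming an in-memory SQLite database:

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Database(Base):
    __tablename__ = "dbs"
    id = Column(Integer, primary_key=True)
    database_name = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

db_row = Database(database_name="examples")
session.add(db_row)
session.flush()     # emits the INSERT and assigns the primary key...
print(db_row.id)    # ...so the id is usable here
session.rollback()  # ...yet the enclosing transaction can still undo it
print(session.query(Database).count())  # 0
```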
diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py
index 30c668bab5556..844a8f063c1b8 100644
--- a/superset/utils/decorators.py
+++ b/superset/utils/decorators.py
@@ -222,7 +222,7 @@ def on_error(

     :param ex: The source exception
     :param catches: The exception types the handler catches
-    :param reraise: The exception type the handler reraises after catching
+    :param reraise: The exception type the handler raises after catching
     :raises Exception: If the exception is not swallowed
     """

@@ -240,7 +240,11 @@ def transaction(  # pylint: disable=redefined-outer-name
     on_error: Callable[..., Any] | None = on_error,
 ) -> Callable[..., Any]:
     """
-    Perform a "unit of work" by leveraging SQLAlchemy's nested transaction.
+    Perform a "unit of work".
+
+    Note ideally this would leverage SQLAlchemy's nested transaction, however this
+    proved rather complicated, likely due to many architectural facets, and thus has
+    been left for a follow up exercise.

     :param on_error: Callback invoked when an exception is caught
     :see: https://github.com/apache/superset/issues/25108
@@ -252,13 +256,16 @@ def wrapped(*args: Any, **kwargs: Any) -> Any:
         from superset import db  # pylint: disable=import-outside-toplevel

         try:
-            with db.session.begin_nested():
-                return func(*args, **kwargs)
+            result = func(*args, **kwargs)
+            db.session.commit()  # pylint: disable=consider-using-transaction
+            return result
         except Exception as ex:
+            db.session.rollback()  # pylint: disable=consider-using-transaction
+
             if on_error:
                 return on_error(ex)

-            raise ex
+            raise

     return wrapped
diff --git a/superset/utils/lock.py b/superset/utils/lock.py
index 3cd3c8ead53ab..4723b57fa1b01 100644
--- a/superset/utils/lock.py
+++ b/superset/utils/lock.py
@@ -24,7 +24,6 @@
 from datetime import datetime, timedelta
 from typing import Any, cast, TypeVar, Union

-from superset import db
 from superset.exceptions import CreateKeyValueDistributedLockFailedException
 from superset.key_value.exceptions import KeyValueCreateFailedError
 from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
@@ -72,7 +71,6 @@ def KeyValueDistributedLock(  # pylint: disable=invalid-name
     store.

     :param namespace: The namespace for which the lock is to be acquired.
-    :type namespace: str
     :param kwargs: Additional keyword arguments.
     :yields: A unique identifier (UUID) for the acquired lock (the KV key).
     :raises CreateKeyValueDistributedLockFailedException: If the lock is taken.
@@ -93,12 +91,10 @@ def KeyValueDistributedLock(  # pylint: disable=invalid-name
             value=True,
             expires_on=datetime.now() + LOCK_EXPIRATION,
         ).run()
-        db.session.commit()

         yield key

         DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run()
-        db.session.commit()
         logger.debug("Removed lock on namespace %s for key %s", namespace, key)
     except KeyValueCreateFailedError as ex:
         raise CreateKeyValueDistributedLockFailedException(
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 730bb7c43fbb0..71c552883307d 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -403,6 +403,7 @@ def log(  # pylint: disable=too-many-arguments,too-many-locals
             logs.append(log)
         try:
             db.session.bulk_save_objects(logs)
+            db.session.commit()  # pylint: disable=consider-using-transaction
         except SQLAlchemyError as ex:
             logging.error("DBEventLogger failed to log event(s)")
             logging.exception(ex)
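For review context, the rewritten decorator reduces to commit-on-success, rollback-then-delegate-on-failure. A self-contained sketch of that control flow; `FakeSession` stands in for `db.session`, so this is an illustration of the shape, not the real implementation:

```python
from functools import wraps
from typing import Any, Callable, Optional


class FakeSession:
    """Stand-in for db.session so the sketch runs anywhere."""

    def commit(self) -> None:
        print("commit")

    def rollback(self) -> None:
        print("rollback")


session = FakeSession()


def transaction(on_error: Optional[Callable[[Exception], Any]] = None):
    def decorate(func):
        @wraps(func)
        def wrapped(*args, **kwargs):
            try:
                result = func(*args, **kwargs)
                session.commit()    # success: close the unit of work
                return result
            except Exception as ex:
                session.rollback()  # failure: undo everything first
                if on_error:
                    return on_error(ex)
                raise

        return wrapped

    return decorate


@transaction()
def ok() -> str:
    return "done"


@transaction(on_error=lambda ex: "handled")
def boom() -> str:
    raise ValueError("nope")


print(ok())    # commit, then "done"
print(boom())  # rollback, then "handled"
```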
diff --git a/superset/views/database/views.py b/superset/views/database/views.py
index d2ccd49ba5714..019dc1138bd11 100644
--- a/superset/views/database/views.py
+++ b/superset/views/database/views.py
@@ -14,8 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=consider-using-transaction
-from typing import Any, TYPE_CHECKING
+from typing import TYPE_CHECKING

 from flask_appbuilder import expose
 from flask_appbuilder.models.sqla.interface import SQLAInterface
diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py
index 3f4cab16ad8bc..77633d65642e9 100644
--- a/tests/integration_tests/base_tests.py
+++ b/tests/integration_tests/base_tests.py
@@ -349,15 +349,7 @@ def grant_public_access_to_table(self, table):
         self.grant_role_access_to_table(table, role_name)

     def grant_role_access_to_table(self, table, role_name):
-        print(">>> grant_role_access_to_table <<<")
-        print(role_name)
-        print(db.session.get_bind())
-        from flask_appbuilder.security.sqla.models import Role
-
-        print(list(db.session.query(Role).all()))
         role = security_manager.find_role(role_name)
-        print(role)
-
         perms = db.session.query(ab_models.PermissionView).all()
         for perm in perms:
             if (
diff --git a/tests/integration_tests/celery_tests.py b/tests/integration_tests/celery_tests.py
index 53690443d59e7..6301365387224 100644
--- a/tests/integration_tests/celery_tests.py
+++ b/tests/integration_tests/celery_tests.py
@@ -267,8 +267,10 @@ def test_run_async_query_cta_config(test_client, ctas_method):
         async_=True,
         tmp_table=tmp_table_name,
     )
+    print(result)

     query = wait_for_success(result)
+    print(query.to_dict())
     assert QueryStatus.SUCCESS == query.status

     assert (
@@ -323,6 +325,9 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
         tmp_table=tmp_table,
     )
     query = wait_for_success(result)
+    print(">>> test_run_async_cta_query_with_lower_limit <<<")
+    print(result)
+    print(query.to_dict())
     assert QueryStatus.SUCCESS == query.status

     sqlite_select_sql = f"SELECT\n  *\nFROM {tmp_table}\nLIMIT {query.limit}\nOFFSET 0"
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index c5d7700c6de28..44b7ef26e64cd 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -91,7 +91,7 @@ def setUp(self):
         self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"]

     def tearDown(self):
-        # db.session.query(Query).delete()
+        db.session.query(Query).delete()
         app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting
         super().tearDown()

@@ -235,7 +235,6 @@ def test_save_slice(self):
         )
         for slc in slices:
             db.session.delete(slc)
-        print(db.session.dirty)
         db.session.commit()

     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
@@ -668,7 +667,6 @@ def test_explore_json_dist_bar_order(self):
             client_id="client_id_1",
             username="admin",
         )
-        print(resp)
         count_ds = []
         count_name = []
         for series in data["data"]:
@@ -816,7 +814,7 @@ def set(self):
         mock_cache.return_value = MockCache()

         rv = self.client.get("/superset/explore_json/data/valid-cache-key")
-        self.assertEqual(rv.status_code, 401)
+        self.assertEqual(rv.status_code, 403)

     def test_explore_json_data_invalid_cache_key(self):
         self.login(ADMIN_USERNAME)
diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py
index 3a0c3ef21720a..1b900ecbcc8a2 100644
--- a/tests/integration_tests/dashboard_utils.py
+++ b/tests/integration_tests/dashboard_utils.py
@@ -97,5 +97,6 @@ def create_dashboard(
     if slices is not None:
         dash.slices = slices
     db.session.add(dash)
+    db.session.commit()

     return dash
diff --git a/tests/integration_tests/dashboards/commands_tests.py b/tests/integration_tests/dashboards/commands_tests.py
index 06edd6c6d0f18..334e0425cf1f3 100644
--- a/tests/integration_tests/dashboards/commands_tests.py
+++ b/tests/integration_tests/dashboards/commands_tests.py
@@ -592,7 +592,6 @@ def test_import_v1_dashboard_multiple(self, mock_g):
         }
         command = v1.ImportDashboardsCommand(contents, overwrite=True)
         command.run()
-        command.run()

         new_num_dashboards = db.session.query(Dashboard).count()
         assert new_num_dashboards == num_dashboards + 1
diff --git a/tests/integration_tests/dashboards/dao_tests.py b/tests/integration_tests/dashboards/dao_tests.py
index eb9207423e73a..83ef02730b0a8 100644
--- a/tests/integration_tests/dashboards/dao_tests.py
+++ b/tests/integration_tests/dashboards/dao_tests.py
@@ -48,15 +48,16 @@ def test_get_dashboard_changed_on(self, mock_sm_g, mock_g):
         assert changed_on == DashboardDAO.get_dashboard_changed_on("world_health")

         old_changed_on = dashboard.changed_on
+        # freezegun doesn't work for some reason, so we need to sleep here :(
         time.sleep(1)
         data = dashboard.data
         positions = data["position_json"]
         data.update({"positions": positions})
         original_data = copy.deepcopy(data)
+
         data.update({"foo": "bar"})
         DashboardDAO.set_dash_metadata(dashboard, data)
-        db.session.flush()
         db.session.commit()
         new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
         assert old_changed_on.replace(microsecond=0) < new_changed_on
@@ -68,7 +69,6 @@ def test_get_dashboard_changed_on(self, mock_sm_g, mock_g):
         )

         DashboardDAO.set_dash_metadata(dashboard, original_data)
-        db.session.flush()
         db.session.commit()

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index 9f9882c99bde6..5de531064d8c3 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -281,10 +281,11 @@ def test_create_database(self):
             "server_cert": None,
             "extra": json.dumps(extra),
         }
-
+        print(database_data)
         uri = "api/v1/database/"
         rv = self.client.post(uri, json=database_data)
         response = json.loads(rv.data.decode("utf-8"))
+        print(response)
         self.assertEqual(rv.status_code, 201)
         # Cleanup
         model = db.session.query(Database).get(response.get("id"))
@@ -700,6 +701,7 @@ def test_cascade_delete_ssh_tunnel(
         mock_create_is_feature_enabled.return_value = True
         self.login(ADMIN_USERNAME)
         example_db = get_example_database()
+        print(example_db)
         if example_db.backend == "sqlite":
             return
         ssh_tunnel_properties = {
@@ -713,10 +715,11 @@ def test_cascade_delete_ssh_tunnel(
             "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
             "ssh_tunnel": ssh_tunnel_properties,
         }
-
+        print(database_data)
         uri = "api/v1/database/"
         rv = self.client.post(uri, json=database_data)
         response = json.loads(rv.data.decode("utf-8"))
+        print(response)
         self.assertEqual(rv.status_code, 201)
         model_ssh_tunnel = (
             db.session.query(SSHTunnel)
@@ -839,16 +842,12 @@ def test_get_database_returns_related_ssh_tunnel(
             db.session.delete(model)
             db.session.commit()

-    @mock.patch(
-        "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
-    )
     @mock.patch("superset.models.core.Database.get_all_catalog_names")
     @mock.patch("superset.models.core.Database.get_all_schema_names")
     def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception(
         self,
         mock_get_all_schema_names,
         mock_get_all_catalog_names,
-        mock_test_connection_database_command_run,
     ):
         """
         Database API: Test raises SSHTunneling feature flag not enabled
@@ -918,6 +917,7 @@ def test_create_database_invalid_configuration_method(self):
         self.login(ADMIN_USERNAME)

         example_db = get_example_database()
+        print(example_db)
         if example_db.backend == "sqlite":
             return
         database_data = {
@@ -927,10 +927,11 @@ def test_create_database_invalid_configuration_method(self):
             "server_cert": None,
             "extra": json.dumps(extra),
         }
-
+        print(database_data)
         uri = "api/v1/database/"
         rv = self.client.post(uri, json=database_data)
         response = json.loads(rv.data.decode("utf-8"))
+        print(response)
         assert response == {
             "message": {
                 "configuration_method": [
@@ -1890,7 +1891,9 @@ def test_get_allow_file_upload_false_csv(self):
         }
         uri = f"api/v1/database/?q={prison.dumps(arguments)}"
         rv = self.client.get(uri)
+        print(rv)
         data = json.loads(rv.data.decode("utf-8"))
+        print(data)
         assert data["count"] == 1

     def test_get_allow_file_upload_filter_no_permission(self):
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 05942ec22aab9..84d8a44066c4c 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -26,6 +26,7 @@
 import pytest
 import yaml
 from sqlalchemy import inspect
+from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import joinedload
 from sqlalchemy.sql import func
diff --git a/tests/integration_tests/embedded/api_tests.py b/tests/integration_tests/embedded/api_tests.py
index 533f1311d3d6c..64afaa178496a 100644
--- a/tests/integration_tests/embedded/api_tests.py
+++ b/tests/integration_tests/embedded/api_tests.py
@@ -44,6 +44,7 @@ def test_get_embedded_dashboard(self):
         self.login(ADMIN_USERNAME)
         self.dash = db.session.query(Dashboard).filter_by(slug="births").first()
         self.embedded = EmbeddedDashboardDAO.upsert(self.dash, [])
+        db.session.flush()
         uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}"
         response = self.client.get(uri)
         self.assert200(response)
diff --git a/tests/integration_tests/embedded/dao_tests.py b/tests/integration_tests/embedded/dao_tests.py
index e1f72feb89db8..eed161581fe71 100644
--- a/tests/integration_tests/embedded/dao_tests.py
+++ b/tests/integration_tests/embedded/dao_tests.py
@@ -34,17 +34,21 @@ def test_upsert(self):
         dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
         assert not dash.embedded
         EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
+        db.session.flush()
         assert dash.embedded
         self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"])
         original_uuid = dash.embedded[0].uuid
         self.assertIsNotNone(original_uuid)
         EmbeddedDashboardDAO.upsert(dash, [])
+        db.session.flush()
         self.assertEqual(dash.embedded[0].allowed_domains, [])
         self.assertEqual(dash.embedded[0].uuid, original_uuid)

     @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
     def test_get_by_uuid(self):
         dash = db.session.query(Dashboard).filter_by(slug="world_health").first()
-        uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid)
+        EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
+        db.session.flush()
+        uuid = str(dash.embedded[0].uuid)
         embedded = EmbeddedDashboardDAO.find_by_id(uuid)
         self.assertIsNotNone(embedded)
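The tests above insert `db.session.flush()` between `EmbeddedDashboardDAO.upsert` and the first use of `embedded.uuid`, because column defaults only fire when the INSERT is emitted. A toy model demonstrating that flush-time behavior; the real model uses a proper UUID column type, this `String` column is just for illustration:

```python
import uuid

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class EmbeddedDashboard(Base):
    __tablename__ = "embedded_dashboards"
    uuid = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

embedded = EmbeddedDashboard()
session.add(embedded)
print(embedded.uuid)  # None: defaults fire at flush time, not at add()
session.flush()
print(embedded.uuid)  # populated: safe to build a request URI from it
```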
diff --git a/tests/integration_tests/embedded/test_view.py b/tests/integration_tests/embedded/test_view.py
index 7fcfcdba9ff0e..f4d5ae6925568 100644
--- a/tests/integration_tests/embedded/test_view.py
+++ b/tests/integration_tests/embedded/test_view.py
@@ -44,6 +44,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]):  # noqa: F811
     dash = db.session.query(Dashboard).filter_by(slug="births").first()
     embedded = EmbeddedDashboardDAO.upsert(dash, [])
+    db.session.flush()
     uri = f"embedded/{embedded.uuid}"
     response = client.get(uri)
     assert response.status_code == 200
@@ -57,6 +58,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]):  # noqa: F811
 def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]):  # noqa: F811
     dash = db.session.query(Dashboard).filter_by(slug="births").first()
     embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"])
+    db.session.flush()
     uri = f"embedded/{embedded.uuid}"
     response = client.get(uri)
     assert response.status_code == 403
diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py
index 3fe2de5944691..513a9f84a24e4 100644
--- a/tests/integration_tests/fixtures/birth_names_dashboard.py
+++ b/tests/integration_tests/fixtures/birth_names_dashboard.py
@@ -46,22 +46,25 @@ def load_birth_names_data(
 @pytest.fixture()
 def load_birth_names_dashboard_with_slices(load_birth_names_data):
     with app.app_context():
-        _create_dashboards()
+        dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
         yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 @pytest.fixture(scope="module")
 def load_birth_names_dashboard_with_slices_module_scope(load_birth_names_data):
     with app.app_context():
-        _create_dashboards()
+        dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
        yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 @pytest.fixture(scope="class")
 def load_birth_names_dashboard_with_slices_class_scope(load_birth_names_data):
     with app.app_context():
-        _create_dashboards()
+        dash_id_to_delete, slices_ids_to_delete = _create_dashboards()
         yield
+        _cleanup(dash_id_to_delete, slices_ids_to_delete)


 def _create_dashboards():
@@ -74,7 +77,10 @@ def _create_dashboards():
     from superset.examples.birth_names import create_dashboard, create_slices

     slices, _ = create_slices(table)
-    create_dashboard(slices)
+    dash = create_dashboard(slices)
+    slices_ids_to_delete = [slice.id for slice in slices]
+    dash_id_to_delete = dash.id
+    return dash_id_to_delete, slices_ids_to_delete


 def _create_table(
@@ -91,4 +97,20 @@ def _create_table(
     _set_table_metadata(table, database)
     _add_table_metrics(table)

+    db.session.commit()
     return table
+
+
+def _cleanup(dash_id: int, slice_ids: list[int]) -> None:
+    schema = get_example_default_schema()
+    for datasource in db.session.query(SqlaTable).filter_by(
+        table_name="birth_names", schema=schema
+    ):
+        for col in datasource.columns + datasource.metrics:
+            db.session.delete(col)
+
+    for dash in db.session.query(Dashboard).filter_by(id=dash_id):
+        db.session.delete(dash)
+    for slc in db.session.query(Slice).filter(Slice.id.in_(slice_ids)):
+        db.session.delete(slc)
+    db.session.commit()
diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py
index dfac9644ae61f..5d938e05416ca 100644
--- a/tests/integration_tests/fixtures/energy_dashboard.py
+++ b/tests/integration_tests/fixtures/energy_dashboard.py
@@ -61,6 +61,7 @@ def load_energy_table_with_slice(load_energy_table_data):
     with app.app_context():
         slices = _create_energy_table()
         yield slices
+        _cleanup()


 def _get_dataframe():
@@ -109,6 +110,24 @@ def _create_and_commit_energy_slice(
     return slice


+def _cleanup() -> None:
+    for slice_data in _get_energy_slices():
+        slice = (
+            db.session.query(Slice)
+            .filter_by(slice_name=slice_data["slice_title"])
+            .first()
+        )
+        db.session.delete(slice)
+
+    metric = (
+        db.session.query(SqlMetric).filter_by(metric_name="sum__value").one_or_none()
+    )
+    if metric:
+        db.session.delete(metric)
+
+    db.session.commit()
+
+
 def _get_energy_data():
     data = []
     for i in range(85):
diff --git a/tests/integration_tests/fixtures/tabbed_dashboard.py b/tests/integration_tests/fixtures/tabbed_dashboard.py
index cf5b9f109cd0a..d4ddff5796348 100644
--- a/tests/integration_tests/fixtures/tabbed_dashboard.py
+++ b/tests/integration_tests/fixtures/tabbed_dashboard.py
@@ -135,4 +135,7 @@ def tabbed_dashboard(app_context):
         slices=[],
     )
     db.session.add(dash)
-    yield
+    db.session.commit()
+    yield dash
+    db.session.query(Dashboard).filter_by(id=dash.id).delete()
+    db.session.commit()
diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py
index e123279e75218..970845783058c 100644
--- a/tests/integration_tests/fixtures/unicode_dashboard.py
+++ b/tests/integration_tests/fixtures/unicode_dashboard.py
@@ -61,6 +61,7 @@ def load_unicode_dashboard_with_slice(load_unicode_data):
     with app.app_context():
         dash = _create_unicode_dashboard(slice_name, None)
         yield
+        _cleanup(dash, slice_name)


 @pytest.fixture()
@@ -70,6 +71,7 @@ def load_unicode_dashboard_with_position(load_unicode_data):
     with app.app_context():
         dash = _create_unicode_dashboard(slice_name, position)
         yield
+        _cleanup(dash, slice_name)


 def _get_dataframe():
@@ -95,18 +97,25 @@ def _create_unicode_dashboard(slice_title: str, position: str) -> Dashboard:
         table.fetch_metadata()

     if slice_title:
-        slice = _create_unicode_slice(table, slice_title)
+        slice = _create_and_commit_unicode_slice(table, slice_title)

     return create_dashboard("unicode-test", "Unicode Test", position, [slice])


-def _create_unicode_slice(table: SqlaTable, title: str):
-    slc = create_slice(title, "word_cloud", table, {})
-    if (
-        obj := db.session.query(Slice)
-        .filter_by(slice_name=slc.slice_name)
-        .one_or_none()
+def _create_and_commit_unicode_slice(table: SqlaTable, title: str):
+    slice = create_slice(title, "word_cloud", table, {})
+    o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
+    if o:
+        db.session.delete(o)
+    db.session.add(slice)
+    db.session.commit()
+    return slice
+
+
+def _cleanup(dash: Dashboard, slice_name: str) -> None:
+    db.session.delete(dash)
+    if slice_name and (
+        slice := db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none()
     ):
-        db.session.delete(obj)
-    db.session.add(slc)
-    return slc
+        db.session.delete(slice)
+    db.session.commit()
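All of these fixtures now follow the same setup/yield/teardown shape, where everything after `yield` runs as cleanup once the test finishes. A minimal standalone version of the pattern, with hypothetical helpers, runnable under pytest:

```python
import pytest

created: list[str] = []


def _create_dashboards() -> str:
    # stand-in for the real fixture setup
    created.append("dash-1")
    return "dash-1"


def _cleanup(dash_id: str) -> None:
    created.remove(dash_id)


@pytest.fixture()
def load_dashboard_with_slices():
    # setup runs before the test...
    dash_id = _create_dashboards()
    yield dash_id
    # ...and everything after the yield runs as teardown
    _cleanup(dash_id)


def test_dashboard_exists(load_dashboard_with_slices):
    assert load_dashboard_with_slices in created
```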
load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data): with app.app_context(): - create_dashboard_for_loaded_data() + dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data() yield _cleanup_reports(dash_id_to_delete, slices_ids_to_delete) _cleanup(dash_id_to_delete, slices_ids_to_delete) @@ -89,8 +89,9 @@ def load_world_bank_dashboard_with_slices_module_scope(load_world_bank_data): @pytest.fixture(scope="class") def load_world_bank_dashboard_with_slices_class_scope(load_world_bank_data): with app.app_context(): - create_dashboard_for_loaded_data() + dash_id_to_delete, slices_ids_to_delete = create_dashboard_for_loaded_data() yield + _cleanup(dash_id_to_delete, slices_ids_to_delete) def create_dashboard_for_loaded_data(): @@ -102,22 +103,24 @@ def create_dashboard_for_loaded_data(): return dash_id_to_delete, slices_ids_to_delete -def _create_world_bank_slices(table: SqlaTable) -> None: +def _create_world_bank_slices(table: SqlaTable) -> list[Slice]: from superset.examples.world_bank import create_slices slices = create_slices(table) + _commit_slices(slices) + return slices + - for slc in slices: - if ( - obj := db.session.query(Slice) - .filter_by(slice_name=slc.slice_name) - .one_or_none() - ): - db.session.delete(obj) - db.session.add(slc) +def _commit_slices(slices: list[Slice]): + for slice in slices: + o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none() + if o: + db.session.delete(o) + db.session.add(slice) + db.session.commit() -def _create_world_bank_dashboard(table: SqlaTable) -> None: +def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard: from superset.examples.helpers import update_slice_ids from superset.examples.world_bank import dashboard_positions @@ -130,6 +133,16 @@ def _create_world_bank_dashboard(table: SqlaTable) -> None: "world_health", "World Bank's Data", json.dumps(pos), slices ) dash.json_metadata = '{"mock_key": "mock_value"}' + db.session.commit() + return dash + + +def _cleanup(dash_id: int, slices_ids: list[int]) -> None: + dash = db.session.query(Dashboard).filter_by(id=dash_id).first() + db.session.delete(dash) + for slice_id in slices_ids: + db.session.query(Slice).filter_by(id=slice_id).delete() + db.session.commit() def _cleanup_reports(dash_id: int, slices_ids: list[int]) -> None: diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 2c8a13a71f4d6..71bb1484e0330 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -215,8 +215,6 @@ def test_model_view_rls_add_name_unique(self): }, ) self.assertEqual(rv.status_code, 422) - data = json.loads(rv.data.decode("utf-8")) - assert "Create failed" in data["message"] @pytest.mark.usefixtures("create_dataset") def test_model_view_rls_add_tables_required(self): diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index a36cb8a8ec35a..829854d966810 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -73,7 +73,6 @@ class TestSqlLab(SupersetTestCase): def run_some_queries(self): db.session.query(Query).delete() - db.session.commit() self.run_sql(QUERY_1, client_id="client_id_1", username="admin") self.run_sql(QUERY_2, client_id="client_id_2", username="admin") self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab") diff --git 
diff --git a/tests/integration_tests/tags/dao_tests.py b/tests/integration_tests/tags/dao_tests.py
index dbd0360aa75fe..8a6ba6e5f4b3a 100644
--- a/tests/integration_tests/tags/dao_tests.py
+++ b/tests/integration_tests/tags/dao_tests.py
@@ -187,6 +187,7 @@ def test_get_objects_from_tag(self):
                     TaggedObject.object_type == ObjectType.chart,
                 ),
             )
+            .join(Tag, TaggedObject.tag_id == Tag.id)
             .distinct(Slice.id)
             .count()
         )
@@ -199,6 +200,7 @@ def test_get_objects_from_tag(self):
                     TaggedObject.object_type == ObjectType.dashboard,
                 ),
             )
+            .join(Tag, TaggedObject.tag_id == Tag.id)
             .distinct(Dashboard.id)
             .count()
             + num_charts
diff --git a/tests/unit_tests/commands/databases/create_test.py b/tests/unit_tests/commands/databases/create_test.py
index 405238827d5cf..09d5744afd53b 100644
--- a/tests/unit_tests/commands/databases/create_test.py
+++ b/tests/unit_tests/commands/databases/create_test.py
@@ -29,7 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock:
     """
     Mock a database with catalogs and schemas.
     """
-    mocker.patch("superset.commands.database.create.db")
     mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")

     database = mocker.MagicMock()
@@ -53,7 +52,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock:
     """
     Mock a database without catalogs.
     """
-    mocker.patch("superset.commands.database.create.db")
     mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand")

     database = mocker.MagicMock()
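
The deleted mocker.patch lines were stubbing out a module-level db object that the database commands no longer touch directly, so the patches had become dead weight. For reference, a sketch of the pytest-mock idiom itself, aimed at a stdlib target so it runs as-is:

# Illustrative pytest-mock usage; `describe_cwd` is a hypothetical helper.
import os
from pytest_mock import MockerFixture

def describe_cwd() -> str:
    return f"running in {os.getcwd()}"

def test_describe_cwd(mocker: MockerFixture) -> None:
    # Replace the attribute where it is looked up; the patch is undone
    # automatically when the test finishes, so no manual cleanup is needed.
    mocker.patch("os.getcwd", return_value="/tmp/sandbox")
    assert describe_cwd() == "running in /tmp/sandbox"
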
""" - mocker.patch("superset.commands.database.update.db") - database = mocker.MagicMock() database.database_name = "my_db" database.db_engine_spec.__name__ = "test_engine" diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py index d1907fb6cb982..7662393d4fc49 100644 --- a/tests/unit_tests/dao/tag_test.py +++ b/tests/unit_tests/dao/tag_test.py @@ -22,7 +22,6 @@ def test_user_favorite_tag(mocker): from superset.daos.tag import TagDAO # Mock the behavior of TagDAO and g - mock_session = mocker.patch("superset.daos.tag.db.session") mock_TagDAO = mocker.patch( "superset.daos.tag.TagDAO" ) # Replace with the actual path to TagDAO @@ -45,7 +44,6 @@ def test_remove_user_favorite_tag(mocker): from superset.daos.tag import TagDAO # Mock the behavior of TagDAO and g - mock_session = mocker.patch("superset.daos.tag.db.session") mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO") mock_tag = mocker.MagicMock(users_favorited=[]) mock_TagDAO.find_by_id.return_value = mock_tag diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 6eeb7ff162894..f4534d216b9b7 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -116,7 +116,7 @@ def test_post_with_uuid( assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" database = session.query(Database).one() - assert database.uuid == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" + assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb") def test_password_mask( diff --git a/tests/unit_tests/utils/lock_tests.py b/tests/unit_tests/utils/lock_tests.py index aa231bb0cf8f2..4c9121fe38744 100644 --- a/tests/unit_tests/utils/lock_tests.py +++ b/tests/unit_tests/utils/lock_tests.py @@ -22,8 +22,8 @@ import pytest from freezegun import freeze_time -from sqlalchemy.orm import Session, sessionmaker +from superset import db from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.key_value.types import JsonKeyValueCodec from superset.utils.lock import get_key, KeyValueDistributedLock @@ -32,56 +32,51 @@ OTHER_KEY = get_key("ns2", a=1, b=2) -def _get_lock(key: UUID, session: Session) -> Any: +def _get_lock(key: UUID) -> Any: from superset.key_value.models import KeyValueEntry - entry = session.query(KeyValueEntry).filter_by(uuid=key).first() + entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first() if entry is None or entry.is_expired(): return None return JsonKeyValueCodec().decode(entry.value) -def _get_other_session() -> Session: - # This session is used to simulate what another worker will find in the metastore - # during the locking process. - from superset import db - - bind = db.session.get_bind() - SessionMaker = sessionmaker(bind=bind) - return SessionMaker() - - def test_key_value_distributed_lock_happy_path() -> None: """ Test successfully acquiring and returning the distributed lock. + + Note we use a nested transaction to ensure that the cleanup from the outer context + manager is correctly invoked, otherwise a partial rollback would occur leaving the + database in a fractured state. 
""" - session = _get_other_session() with freeze_time("2021-01-01"): - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None + with KeyValueDistributedLock("ns", a=1, b=2) as key: assert key == MAIN_KEY - assert _get_lock(key, session) is True - assert _get_lock(OTHER_KEY, session) is None - with pytest.raises(CreateKeyValueDistributedLockFailedException): - with KeyValueDistributedLock("ns", a=1, b=2): - pass + assert _get_lock(key) is True + assert _get_lock(OTHER_KEY) is None + + with db.session.begin_nested(): + with pytest.raises(CreateKeyValueDistributedLockFailedException): + with KeyValueDistributedLock("ns", a=1, b=2): + pass - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None def test_key_value_distributed_lock_expired() -> None: """ Test expiration of the distributed lock """ - session = _get_other_session() - with freeze_time("2021-01-01T"): - assert _get_lock(MAIN_KEY, session) is None + with freeze_time("2021-01-01"): + assert _get_lock(MAIN_KEY) is None with KeyValueDistributedLock("ns", a=1, b=2): - assert _get_lock(MAIN_KEY, session) is True - with freeze_time("2022-01-01T"): - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is True + with freeze_time("2022-01-01"): + assert _get_lock(MAIN_KEY) is None - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None