diff --git a/superset/cli/viz_migrations.py b/superset/cli/viz_migrations.py index ff211f526b902..58e82ea196bc9 100644 --- a/superset/cli/viz_migrations.py +++ b/superset/cli/viz_migrations.py @@ -17,11 +17,31 @@ from __future__ import annotations from enum import Enum +from typing import Type import click +from click_option_group import optgroup, RequiredAnyOptionGroup from flask.cli import with_appcontext from superset import db +from superset.migrations.shared.migrate_viz.base import ( + MigrateViz, + Slice, +) +from superset.migrations.shared.migrate_viz.processors import ( + MigrateAreaChart, + MigrateBarChart, + MigrateBubbleChart, + MigrateDistBarChart, + MigrateDualLine, + MigrateHeatmapChart, + MigrateHistogramChart, + MigrateLineChart, + MigratePivotTable, + MigrateSunburst, + MigrateTreeMap, +) +from superset.migrations.shared.utils import paginated_update class VizType(str, Enum): @@ -38,6 +58,25 @@ class VizType(str, Enum): TREEMAP = "treemap" +MIGRATIONS: dict[VizType, Type[MigrateViz]] = { + VizType.AREA: MigrateAreaChart, + VizType.BAR: MigrateBarChart, + VizType.BUBBLE: MigrateBubbleChart, + VizType.DIST_BAR: MigrateDistBarChart, + VizType.DUAL_LINE: MigrateDualLine, + VizType.HEATMAP: MigrateHeatmapChart, + VizType.HISTOGRAM: MigrateHistogramChart, + VizType.LINE: MigrateLineChart, + VizType.PIVOT_TABLE: MigratePivotTable, + VizType.SUNBURST: MigrateSunburst, + VizType.TREEMAP: MigrateTreeMap, +} + +PREVIOUS_VERSION = { + migration.target_viz_type: migration for migration in MIGRATIONS.values() +} + + @click.group() def migrate_viz() -> None: """ @@ -47,73 +86,82 @@ def migrate_viz() -> None: @migrate_viz.command() @with_appcontext -@click.option( +@optgroup.group( + cls=RequiredAnyOptionGroup, +) +@optgroup.option( "--viz_type", "-t", help=f"The viz type to upgrade: {', '.join(list(VizType))}", - required=True, + type=str, ) -@click.option( - "--chart_id", - help="The chart ID to upgrade", - type=int, +@optgroup.option( + "-ids", + help="A comma separated list of chart IDs to upgrade", + type=str, ) -def upgrade(viz_type: str, chart_id: int | None = None) -> None: +def upgrade(viz_type: str, ids: str | None = None) -> None: """Upgrade a viz to the latest version.""" - migrate(VizType(viz_type), chart_id) + if ids is None: + migrate_by_viz_type(VizType(viz_type)) + else: + migrate_by_ids(ids) @migrate_viz.command() @with_appcontext -@click.option( +@optgroup.group( + cls=RequiredAnyOptionGroup, +) +@optgroup.option( "--viz_type", "-t", help=f"The viz type to downgrade: {', '.join(list(VizType))}", - required=True, + type=str, ) -@click.option( - "--chart_id", - help="The chart ID to downgrade", - type=int, +@optgroup.option( + "-ids", + help="A comma separated list of chart IDs to downgrade", + type=str, ) -def downgrade(viz_type: str, chart_id: int | None = None) -> None: +def downgrade(viz_type: str, ids: str | None = None) -> None: """Downgrade a viz to the previous version.""" - migrate(VizType(viz_type), chart_id, is_downgrade=True) - - -def migrate( - viz_type: VizType, chart_id: int | None = None, is_downgrade: bool = False -) -> None: - """Migrate a viz from one type to another.""" - # pylint: disable=import-outside-toplevel - from superset.migrations.shared.migrate_viz.processors import ( - MigrateAreaChart, - MigrateBarChart, - MigrateBubbleChart, - MigrateDistBarChart, - MigrateDualLine, - MigrateHeatmapChart, - MigrateHistogramChart, - MigrateLineChart, - MigratePivotTable, - MigrateSunburst, - MigrateTreeMap, - ) - - migrations = { - VizType.AREA: MigrateAreaChart, - VizType.BAR: MigrateBarChart, - VizType.BUBBLE: MigrateBubbleChart, - VizType.DIST_BAR: MigrateDistBarChart, - VizType.DUAL_LINE: MigrateDualLine, - VizType.HEATMAP: MigrateHeatmapChart, - VizType.HISTOGRAM: MigrateHistogramChart, - VizType.LINE: MigrateLineChart, - VizType.PIVOT_TABLE: MigratePivotTable, - VizType.SUNBURST: MigrateSunburst, - VizType.TREEMAP: MigrateTreeMap, - } + if ids is None: + migrate_by_viz_type(VizType(viz_type), is_downgrade=True) + else: + migrate_by_ids(ids, is_downgrade=True) + + +def migrate_by_viz_type(viz_type: VizType, is_downgrade: bool = False) -> None: + """ + Migrate all charts of a viz type. + + :param viz_type: The viz type to migrate + :param is_downgrade: Whether to downgrade the charts. Default is upgrade. + """ + migration: Type[MigrateViz] = MIGRATIONS[viz_type] if is_downgrade: - migrations[viz_type].downgrade(db.session, chart_id) + migration.downgrade(db.session) else: - migrations[viz_type].upgrade(db.session, chart_id) + migration.upgrade(db.session) + + +def migrate_by_ids(ids: str, is_downgrade: bool = False) -> None: + """ + Migrate a subset of charts by a list of IDs. + + :param ids: List of chart IDs to migrate + :param is_downgrade: Whether to downgrade the charts. Default is upgrade. + """ + id_list = [int(i) for i in ids.split(",")] + slices = db.session.query(Slice).filter(Slice.id.in_(id_list)) + for slc in paginated_update( + slices, + lambda current, total: print( + f"{('Downgraded' if is_downgrade else 'Upgraded')} {current}/{total} charts" + ), + ): + if is_downgrade: + PREVIOUS_VERSION[slc.viz_type].downgrade_slice(slc) + else: + MIGRATIONS[slc.viz_type].upgrade_slice(slc) diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index 6198cbfba62e4..b013aa0be2b20 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import annotations - import copy from typing import Any @@ -153,13 +151,8 @@ def downgrade_slice(cls, slc: Slice) -> None: slc.query_context = json.dumps(query_context) @classmethod - def upgrade(cls, session: Session, chart_id: int | None = None) -> None: - slices = session.query(Slice).filter( - and_( - Slice.viz_type == cls.source_viz_type, - Slice.id == chart_id if chart_id is not None else True, - ) - ) + def upgrade(cls, session: Session) -> None: + slices = session.query(Slice).filter(Slice.viz_type == cls.source_viz_type) for slc in paginated_update( slices, lambda current, total: print(f"Upgraded {current}/{total} charts"), @@ -167,12 +160,11 @@ def upgrade(cls, session: Session, chart_id: int | None = None) -> None: cls.upgrade_slice(slc) @classmethod - def downgrade(cls, session: Session, chart_id: int | None = None) -> None: + def downgrade(cls, session: Session) -> None: slices = session.query(Slice).filter( and_( Slice.viz_type == cls.target_viz_type, Slice.params.like(f"%{FORM_DATA_BAK_FIELD_NAME}%"), - Slice.id == chart_id if chart_id is not None else True, ) ) for slc in paginated_update(