diff --git a/requirements/base.txt b/requirements/base.txt index 4dc66c1e4ade..400dca59d147 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -292,6 +292,8 @@ wtforms==2.3.3 # wtforms-json wtforms-json==0.3.3 # via apache-superset +xlsxwriter==3.0.7 + # via apache-superset # The following packages are considered to be unsafe in a requirements file: # setuptools diff --git a/setup.py b/setup.py index dc546e5a6030..448566d0bc63 100644 --- a/setup.py +++ b/setup.py @@ -124,6 +124,7 @@ def get_git_sha() -> str: "typing-extensions>=4, <5", "wtforms>=2.3.3, <2.4", "wtforms-json", + "xlsxwriter>=3.0.7, <3.1", ], extras_require={ "athena": ["pyathena[pandas]>=2, <3"], diff --git a/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx b/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx index 62fcdaaf15ac..445db6dc4414 100644 --- a/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx +++ b/superset-frontend/src/explore/components/useExploreAdditionalActionsMenu/index.jsx @@ -41,6 +41,7 @@ const MENU_KEYS = { EXPORT_TO_CSV: 'export_to_csv', EXPORT_TO_CSV_PIVOTED: 'export_to_csv_pivoted', EXPORT_TO_JSON: 'export_to_json', + EXPORT_TO_XLSX: 'export_to_xlsx', DOWNLOAD_AS_IMAGE: 'download_as_image', SHARE_SUBMENU: 'share_submenu', COPY_PERMALINK: 'copy_permalink', @@ -165,6 +166,16 @@ export const useExploreAdditionalActionsMenu = ( [latestQueryFormData], ); + const exportExcel = useCallback( + () => + exportChart({ + formData: latestQueryFormData, + resultType: 'results', + resultFormat: 'xlsx', + }), + [latestQueryFormData], + ); + const copyLink = useCallback(async () => { try { if (!latestQueryFormData) { @@ -199,6 +210,11 @@ export const useExploreAdditionalActionsMenu = ( setIsDropdownVisible(false); setOpenSubmenus([]); + break; + case MENU_KEYS.EXPORT_TO_XLSX: + exportExcel(); + setIsDropdownVisible(false); + setOpenSubmenus([]); break; case MENU_KEYS.DOWNLOAD_AS_IMAGE: downloadAsImage( @@ -312,6 +328,12 @@ export const useExploreAdditionalActionsMenu = ( > {t('Download as image')} + } + > + {t('Export to Excel')} + diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 152383e0c66d..0d0758819ed0 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -46,7 +46,7 @@ from superset.extensions import event_logger from superset.utils.async_query_manager import AsyncQueryTokenException from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser -from superset.views.base import CsvResponse, generate_download_headers +from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse from superset.views.base_api import statsd_metrics if TYPE_CHECKING: @@ -353,24 +353,34 @@ def _send_chart_response( if result_type == ChartDataResultType.POST_PROCESSED: result = apply_post_process(result, form_data, datasource) - if result_format == ChartDataResultFormat.CSV: - # Verify user has permission to export CSV file + if result_format in ChartDataResultFormat.table_like(): + # Verify user has permission to export file if not security_manager.can_access("can_csv", "Superset"): return self.response_403() if not result["queries"]: return self.response_400(_("Empty query result")) + is_csv_format = result_format == ChartDataResultFormat.CSV + if len(result["queries"]) == 1: - # return single query results csv format + # return single query results data = result["queries"][0]["data"] - return CsvResponse(data, headers=generate_download_headers("csv")) + if is_csv_format: + return CsvResponse(data, headers=generate_download_headers("csv")) + + return XlsxResponse(data, headers=generate_download_headers("xlsx")) + + # return multi-query results bundled as a zip file + def _process_data(query_data: Any) -> Any: + if result_format == ChartDataResultFormat.CSV: + encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8") + return query_data.encode(encoding) + return query_data - # return multi-query csv results bundled as a zip file - encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8") files = { - f"query_{idx + 1}.csv": result["data"].encode(encoding) - for idx, result in enumerate(result["queries"]) + f"query_{idx + 1}.{result_format}": _process_data(query["data"]) + for idx, query in enumerate(result["queries"]) } return Response( create_zip(files), diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py index ea31d4f13817..659a64015937 100644 --- a/superset/common/chart_data.py +++ b/superset/common/chart_data.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from enum import Enum +from typing import Set class ChartDataResultFormat(str, Enum): @@ -24,6 +25,11 @@ class ChartDataResultFormat(str, Enum): CSV = "csv" JSON = "json" + XLSX = "xlsx" + + @classmethod + def table_like(cls) -> Set["ChartDataResultFormat"]: + return {cls.CSV} | {cls.XLSX} class ChartDataResultType(str, Enum): diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index e6fa964e4d7b..77ca69fcf6f0 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -46,7 +46,7 @@ from superset.extensions import cache_manager, security_manager from superset.models.helpers import QueryResult from superset.models.sql_lab import Query -from superset.utils import csv +from superset.utils import csv, excel from superset.utils.cache import generate_cache_key, set_and_log_cache from superset.utils.core import ( DatasourceType, @@ -446,15 +446,20 @@ def processing_time_offsets( # pylint: disable=too-many-locals,too-many-stateme return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys) def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]: - if self._query_context.result_format == ChartDataResultFormat.CSV: + if self._query_context.result_format in ChartDataResultFormat.table_like(): include_index = not isinstance(df.index, pd.RangeIndex) columns = list(df.columns) verbose_map = self._qc_datasource.data.get("verbose_map", {}) if verbose_map: df.columns = [verbose_map.get(column, column) for column in columns] - result = csv.df_to_escaped_csv( - df, index=include_index, **config["CSV_EXPORT"] - ) + + result = None + if self._query_context.result_format == ChartDataResultFormat.CSV: + result = csv.df_to_escaped_csv( + df, index=include_index, **config["CSV_EXPORT"] + ) + elif self._query_context.result_format == ChartDataResultFormat.XLSX: + result = excel.df_to_excel(df, **config["EXCEL_EXPORT"]) return result or "" return df.to_dict(orient="records") diff --git a/superset/config.py b/superset/config.py index 922d4a981f22..3aac5947189d 100644 --- a/superset/config.py +++ b/superset/config.py @@ -748,6 +748,11 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # note: index option should not be overridden CSV_EXPORT = {"encoding": "utf-8"} +# Excel Options: key/value pairs that will be passed as argument to DataFrame.to_excel +# method. +# note: index option should not be overridden +EXCEL_EXPORT = {"encoding": "utf-8"} + # --------------------------------------------------- # Time grain configurations # --------------------------------------------------- diff --git a/superset/utils/excel.py b/superset/utils/excel.py new file mode 100644 index 000000000000..1f68031b6497 --- /dev/null +++ b/superset/utils/excel.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import io +from typing import Any + +import pandas as pd + + +def df_to_excel(df: pd.DataFrame, **kwargs: Any) -> Any: + output = io.BytesIO() + # pylint: disable=abstract-class-instantiated + with pd.ExcelWriter(output, engine="xlsxwriter") as writer: + df.to_excel(writer, **kwargs) + + return output.getvalue() diff --git a/superset/views/base.py b/superset/views/base.py index 515384082299..ebccd0684b54 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -701,6 +701,17 @@ class CsvResponse(Response): default_mimetype = "text/csv" +class XlsxResponse(Response): + """ + Override Response to use xlsx mimetype + """ + + charset = "utf-8" + default_mimetype = ( + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + ) + + def bind_field( _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any] ) -> Field: diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 164fb0ca6c72..66151362ff1d 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -255,6 +255,16 @@ def test_empty_request_with_csv_result_format(self): rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 400 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_empty_request_with_excel_result_format(self): + """ + Chart data API: Test empty chart data with Excel result format + """ + self.query_context_payload["result_format"] = "xlsx" + self.query_context_payload["queries"] = [] + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 400 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_csv_result_format(self): """ @@ -265,6 +275,17 @@ def test_with_csv_result_format(self): assert rv.status_code == 200 assert rv.mimetype == "text/csv" + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_excel_result_format(self): + """ + Chart data API: Test chart data with Excel result format + """ + self.query_context_payload["result_format"] = "xlsx" + mimetype = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 200 + assert rv.mimetype == mimetype + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_multi_query_csv_result_format(self): """ @@ -280,6 +301,21 @@ def test_with_multi_query_csv_result_format(self): zipfile = ZipFile(BytesIO(rv.data), "r") assert zipfile.namelist() == ["query_1.csv", "query_2.csv"] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_multi_query_excel_result_format(self): + """ + Chart data API: Test chart data with multi-query Excel result format + """ + self.query_context_payload["result_format"] = "xlsx" + self.query_context_payload["queries"].append( + self.query_context_payload["queries"][0] + ) + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 200 + assert rv.mimetype == "application/zip" + zipfile = ZipFile(BytesIO(rv.data), "r") + assert zipfile.namelist() == ["query_1.xlsx", "query_2.xlsx"] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self): """ @@ -292,6 +328,18 @@ def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self): rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") assert rv.status_code == 403 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_with_excel_result_format_when_actor_not_permitted_for_excel__403(self): + """ + Chart data API: Test chart data with Excel result format + """ + self.logout() + self.login(username="gamma_no_csv") + self.query_context_payload["result_format"] = "xlsx" + + rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data") + assert rv.status_code == 403 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_with_row_limit_and_offset__row_limit_and_offset_were_applied(self): """