feat: create dtype option for csv upload (#23716)
eschutho committed Apr 24, 2023
1 parent 4873c09 commit 71106cf
Showing 6 changed files with 160 additions and 2 deletions.
40 changes: 40 additions & 0 deletions superset/db_engine_specs/redshift.py
@@ -18,12 +18,16 @@
import re
from typing import Any, Dict, Optional, Pattern, Tuple

import pandas as pd
from flask_babel import gettext as __
from sqlalchemy.types import NVARCHAR

from superset.db_engine_specs.base import BasicParametersMixin
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sql_parse import Table

logger = logging.getLogger()

@@ -96,6 +100,42 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
),
}

@classmethod
def df_to_sql(
cls,
database: Database,
table: Table,
df: pd.DataFrame,
to_sql_kwargs: Dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
For regular engines this calls the `pandas.DataFrame.to_sql` method.
Overrides the base class so that pandas string columns can be
uploaded as nvarchar(max) columns, since Redshift does not support
text data types.
Note this method does not create metadata for the table.
:param database: The database to upload the data to
:param table: The table to upload the data to
:param df: The dataframe with data to be uploaded
:param to_sql_kwargs: The kwargs to be passed to the `pandas.DataFrame.to_sql` method
"""
to_sql_kwargs = to_sql_kwargs or {}
to_sql_kwargs["dtype"] = {
# uses the max size for a Redshift nvarchar (65535);
# the default object and string types create a varchar(256)
col_name: NVARCHAR(length=65535)
for col_name, col_type in zip(df.columns, df.dtypes)
if isinstance(col_type, pd.StringDtype)
}

super().df_to_sql(
df=df, database=database, table=table, to_sql_kwargs=to_sql_kwargs
)

@staticmethod
def _mutate_label(label: str) -> str:
"""
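For illustration, a minimal standalone sketch (not part of the commit) of what the override's dtype mapping computes: pandas StringDtype columns map to NVARCHAR(65535), Redshift's maximum, while all other columns are left to the defaults. The DataFrame contents are made up.

import pandas as pd
from sqlalchemy.types import NVARCHAR

df = pd.DataFrame({"name": ["alice", "bob"], "age": [30, 40]})
df = df.astype({"name": "string"})  # "name" becomes a pandas StringDtype column

# the same comprehension as in the override above
dtype_map = {
    col_name: NVARCHAR(length=65535)
    for col_name, col_type in zip(df.columns, df.dtypes)
    if isinstance(col_type, pd.StringDtype)
}
print(dtype_map)  # {'name': NVARCHAR(length=65535)}; "age" is not mapped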
@@ -104,6 +104,10 @@
{{ lib.render_field(form.overwrite_duplicate, begin_sep_label, end_sep_label, begin_sep_field,
end_sep_field) }}
</tr>
<tr>
{{ lib.render_field(form.dtype, begin_sep_label, end_sep_label, begin_sep_field,
end_sep_field) }}
</tr>
{% endcall %}
{% call csv_macros.render_collapsable_form_group("accordion3", "Rows") %}
<tr>
10 changes: 10 additions & 0 deletions superset/views/database/forms.py
@@ -140,6 +140,16 @@ class CsvToDatabaseForm(UploadToDatabaseForm):
get_pk=lambda a: a.id,
get_label=lambda a: a.database_name,
)
dtype = StringField(
_("Column Data Types"),
description=_(
"A dictionary with column names and their data types"
" if you need to change the defaults."
' Example: {"user_id":"integer"}'
),
validators=[Optional()],
widget=BS3TextFieldWidget(),
)
schema = StringField(
_("Schema"),
description=_("Select a schema if the database supports this"),
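For context, the new field expects a JSON object mapping column names to pandas dtype strings, as in the field description's own example. A minimal sketch (column names illustrative) of the value it accepts and the parse the view applies to it:

import json

# the raw string a user would type into the "Column Data Types" field
dtype_field_value = '{"user_id": "integer", "signup_date": "string"}'

# views.py runs json.loads on form.dtype.data before handing it to pandas
parsed = json.loads(dtype_field_value)
assert parsed == {"user_id": "integer", "signup_date": "string"}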
3 changes: 3 additions & 0 deletions superset/views/database/views.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import io
import json
import os
import tempfile
import zipfile
@@ -189,6 +190,7 @@ def form_post(self, form: CsvToDatabaseForm) -> Response:
delimiter_input = form.otherInput.data

try:
kwargs = {"dtype": json.loads(form.dtype.data)} if form.dtype.data else {}
df = pd.concat(
pd.read_csv(
chunksize=1000,
@@ -208,6 +210,7 @@ def form_post(self, form: CsvToDatabaseForm) -> Response:
skip_blank_lines=form.skip_blank_lines.data,
skipinitialspace=form.skip_initial_space.data,
skiprows=form.skiprows.data,
**kwargs,
)
)

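An end-to-end sketch (not from the commit; file contents and column types are illustrative) of the path this change adds: the form's JSON string becomes a dtype= argument to pandas.read_csv, and is omitted entirely when the field is left blank.

import io
import json

import pandas as pd

# stand-ins for form.dtype.data and the uploaded file
form_dtype_data = '{"a": "string", "b": "float64"}'
csv_file = io.StringIO("a,b\njohn,1\npaul,2\n")

# mirrors the commit: only pass dtype when the field was filled in
kwargs = {"dtype": json.loads(form_dtype_data)} if form_dtype_data else {}
df = pd.concat(pd.read_csv(csv_file, chunksize=1000, **kwargs))

print(df.dtypes)  # a: string, b: float64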
44 changes: 42 additions & 2 deletions tests/integration_tests/csv_upload_tests.py
@@ -20,7 +20,7 @@
import logging
import os
import shutil
from typing import Dict, Optional
from typing import Dict, Optional, Union

from unittest import mock

@@ -129,7 +129,12 @@ def get_upload_db():
return db.session.query(Database).filter_by(database_name=CSV_UPLOAD_DATABASE).one()


def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] = None):
def upload_csv(
filename: str,
table_name: str,
extra: Optional[Dict[str, str]] = None,
dtype: Union[str, None] = None,
):
csv_upload_db_id = get_upload_db().id
schema = utils.get_example_default_schema()
form_data = {
@@ -145,6 +150,8 @@ def upload_csv(filename: str, table_name: str, extra: Optional[Dict[str, str]] =
form_data["schema"] = schema
if extra:
form_data.update(extra)
if dtype:
form_data["dtype"] = dtype
return get_resp(test_client, "/csvtodatabaseview/form", data=form_data)


@@ -386,6 +393,39 @@ def test_import_csv(mock_event_logger):
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
assert data == [("john", 1, "x"), ("paul", 2, None)]

# cleanup
with get_upload_db().get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE {full_table_name}")

# with dtype
upload_csv(
CSV_FILENAME1,
CSV_UPLOAD_TABLE,
dtype='{"a": "string", "b": "float64"}',
)

# the type can be changed to something compatible, e.g. object to
# string or int to float; the upload should work as normal
with test_db.get_sqla_engine_with_context() as engine:
data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall()
assert data == [("john", 1), ("paul", 2)]

# cleanup
with get_upload_db().get_sqla_engine_with_context() as engine:
engine.execute(f"DROP TABLE {full_table_name}")

# with dtype - wrong type
resp = upload_csv(
CSV_FILENAME1,
CSV_UPLOAD_TABLE,
dtype='{"a": "int"}',
)

# you cannot pass an incompatible dtype
fail_msg = f"Unable to upload CSV file {escaped_double_quotes(CSV_FILENAME1)} to table {escaped_double_quotes(CSV_UPLOAD_TABLE)}"
assert fail_msg in resp


@pytest.mark.usefixtures("setup_csv_upload_with_context")
@pytest.mark.usefixtures("create_excel_files")
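The compatible/incompatible distinction the new test comments describe can be reproduced with pandas alone; a standalone sketch (not from the commit, data made up):

import io

import pandas as pd

csv = "a,b\njohn,1\npaul,2\n"

# compatible overrides: object -> string, int -> float64
df = pd.read_csv(io.StringIO(csv), dtype={"a": "string", "b": "float64"})
print(df.dtypes)

# incompatible override: "john" cannot be parsed as an integer, so
# read_csv raises ValueError -- the failure path the test asserts on
try:
    pd.read_csv(io.StringIO(csv), dtype={"a": "int"})
except ValueError as exc:
    print(exc)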
61 changes: 61 additions & 0 deletions tests/integration_tests/db_engine_specs/redshift_tests.py
@@ -14,11 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest.mock as mock
from textwrap import dedent

import numpy as np
import pandas as pd
from sqlalchemy.types import NVARCHAR

from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql_parse import Table
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.test_app import app


class TestRedshiftDbEngineSpec(TestDbEngineSpec):
@@ -183,3 +190,57 @@ def test_extract_errors(self):
},
)
]

def test_df_to_sql_no_dtype(self):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
table_name = "foobar"
data = [
("foo", "bar", pd.NA, None),
("foo", "bar", pd.NA, True),
("foo", "bar", pd.NA, None),
]
numpy_dtype = [
("id", "object"),
("value", "object"),
("num", "object"),
("bool", "object"),
]
column_names = ["id", "value", "num", "bool"]

test_array = np.array(data, dtype=numpy_dtype)

df = pd.DataFrame(test_array, columns=column_names)
df.to_sql = mock.MagicMock()

with app.app_context():
RedshiftEngineSpec.df_to_sql(
mock_database, Table(table=table_name), df, to_sql_kwargs={}
)

assert df.to_sql.call_args[1]["dtype"] == {}

def test_df_to_sql_with_string_dtype(self):
mock_database = mock.MagicMock()
mock_database.get_df.return_value.empty = False
table_name = "foobar"
data = [
("foo", "bar", pd.NA, None),
("foo", "bar", pd.NA, True),
("foo", "bar", pd.NA, None),
]
column_names = ["id", "value", "num", "bool"]

df = pd.DataFrame(data, columns=column_names)
df = df.astype(dtype={"value": "string"})
df.to_sql = mock.MagicMock()

with app.app_context():
RedshiftEngineSpec.df_to_sql(
mock_database, Table(table=table_name), df, to_sql_kwargs={}
)

# nvarchar string length should be 65535
dtype = df.to_sql.call_args[1]["dtype"]
assert isinstance(dtype["value"], NVARCHAR)
assert dtype["value"].length == 65535
