Skip to content

Commit

Permalink
fix: sqlparse fallback for formatting queries (apache#30578)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Oct 11, 2024
1 parent 9a2b1a5 commit 47c1e09
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 35 deletions.
108 changes: 88 additions & 20 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from typing import Any, Generic, TypeVar

import sqlglot
import sqlparse
from deprecation import deprecated
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
Expand Down Expand Up @@ -138,24 +140,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
Base class for SQL statements.
The class can be instantiated with a string representation of the script or, for
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
which will split a script in multiple already parsed statements.
The class should be instantiated with a string representation of the script and, for
efficiency reasons, optionally with a pre-parsed AST. This is useful with
`sqlglot.parse`, which will split a script in multiple already parsed statements.
The `engine` parameters comes from the `engine` attribute in a Superset DB engine
spec.
"""

def __init__(
self,
statement: str | InternalRepresentation,
statement: str,
engine: str,
ast: InternalRepresentation | None = None,
):
self._parsed: InternalRepresentation = (
self._parse_statement(statement, engine)
if isinstance(statement, str)
else statement
)
self._sql = statement
self._parsed = ast or self._parse_statement(statement, engine)
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)

Expand Down Expand Up @@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):

def __init__(
self,
statement: str | exp.Expression,
statement: str,
engine: str,
ast: exp.Expression | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine)
super().__init__(statement, engine, ast)

@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
Expand Down Expand Up @@ -275,11 +276,47 @@ def split_script(
script: str,
engine: str,
) -> list[SQLStatement]:
return [
cls(statement, engine)
for statement in cls._parse(script, engine)
if statement
]
if engine in SQLGLOT_DIALECTS:
try:
return [
cls(ast.sql(), engine, ast)
for ast in cls._parse(script, engine)
if ast
]
except ValueError:
# `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
# FROM`). In this case, we rely on the tokenizer to generate the
# statements.
pass

# When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
# generate the SQL of each statement, so we tokenize the script and split it
# based on the location of semi-colons.
statements = []
start = 0
remainder = script

try:
tokens = sqlglot.tokenize(script)
except sqlglot.errors.TokenError as ex:
raise SupersetParseError(
script,
engine,
message="Unable to tokenize script",
) from ex

for token in tokens:
if token.token_type == sqlglot.TokenType.SEMICOLON:
statement, start = script[start : token.start], token.end + 1
ast = cls._parse(statement, engine)[0]
statements.append(cls(statement.strip(), engine, ast))
remainder = script[start:]

if remainder.strip():
ast = cls._parse(remainder, engine)[0]
statements.append(cls(remainder.strip(), engine, ast))

return statements

@classmethod
def _parse_statement(
Expand Down Expand Up @@ -349,8 +386,34 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
write = Dialect.get_or_raise(self._dialect)
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
if self._dialect:
try:
write = Dialect.get_or_raise(self._dialect)
return write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
except ValueError:
pass

return self._fallback_formatting()

@deprecated(deprecated_in="4.0", removed_in="5.0")
def _fallback_formatting(self) -> str:
"""
Format SQL without a specific dialect.
Reformatting SQL using the generic sqlglot dialect is known to break queries.
For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
when the dialect is not known.
In 5.0 we should remove `sqlparse`, and the method should return the query
unmodified.
"""
return sqlparse.format(self._sql, reindent=True, keyword_case="upper")

def get_settings(self) -> dict[str, str | bool]:
"""
Expand Down Expand Up @@ -456,7 +519,9 @@ def split_script(
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
for more information.
"""
return [cls(statement, engine) for statement in split_kql(script)]
return [
cls(statement, engine, statement.strip()) for statement in split_kql(script)
]

@classmethod
def _parse_statement(
Expand Down Expand Up @@ -498,7 +563,7 @@ def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
"""
return self._parsed
return self._sql.strip()

def get_settings(self) -> dict[str, str | bool]:
"""
Expand Down Expand Up @@ -548,6 +613,9 @@ def __init__(
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL script.
Note that even though KQL is very different from SQL, multiple statements are
still separated by semi-colons.
"""
return ";\n".join(statement.format(comments) for statement in self.statements)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/sql_lab/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_format_sql_request(self):
"/api/v1/sqllab/format_sql/",
json=data,
)
success_resp = {"result": "SELECT\n 1\nFROM my_table"}
success_resp = {"result": "SELECT 1\nFROM my_table"}
resp_data = json.loads(rv.data.decode("utf-8"))
self.assertDictEqual(resp_data, success_resp) # noqa: PT009
assert rv.status_code == 200
Expand Down
16 changes: 2 additions & 14 deletions tests/unit_tests/db_engine_specs/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec):
latest_partition=False,
cols=cols,
)
assert (
sql
== """SELECT
a
FROM my_table
LIMIT ?
OFFSET ?"""
)
assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?"

sql = NoLimitDBEngineSpec.select_star(
database=database,
Expand All @@ -260,12 +253,7 @@ class NoLimitDBEngineSpec(BaseEngineSpec):
latest_partition=False,
cols=cols,
)
assert (
sql
== """SELECT
a
FROM my_table"""
)
assert sql == "SELECT a\nFROM my_table"


def test_extra_table_metadata(mocker: MockerFixture) -> None:
Expand Down
34 changes: 34 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,40 @@ def test_extract_tables_show_tables_from() -> None:
)


def test_format_show_tables() -> None:
"""
Test format when `ast.sql()` raises an exception.
In that case sqlparse should be used instead.
"""
assert (
SQLScript("SHOW TABLES FROM s1 like '%order%'", "mysql").format()
== "SHOW TABLES FROM s1 LIKE '%order%'"
)


def test_format_no_dialect() -> None:
"""
Test format with an engine that has no corresponding dialect.
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "firebolt").format()
== "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
)


def test_split_no_dialect() -> None:
"""
Test the statement split when the engine has no corresponding dialect.
"""
sql = "SELECT col FROM t WHERE col NOT IN (1, 2); SELECT * FROM t; SELECT foo"
statements = SQLScript(sql, "firebolt").statements
assert len(statements) == 3
assert statements[0]._sql == "SELECT col FROM t WHERE col NOT IN (1, 2)"
assert statements[1]._sql == "SELECT * FROM t"
assert statements[2]._sql == "SELECT foo"


def test_extract_tables_show_columns_from() -> None:
"""
Test `SHOW COLUMNS FROM`.
Expand Down

0 comments on commit 47c1e09

Please sign in to comment.