From 357986103b211783455768ad33a4366bec04c578 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Sat, 19 Aug 2023 15:49:15 +0100 Subject: [PATCH] fix: CTE queries with non-SELECT statements (#25014) --- superset/sql_parse.py | 55 ++++++++++++++++++++ tests/unit_tests/sql_parse_tests.py | 81 +++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) diff --git a/superset/sql_parse.py b/superset/sql_parse.py index c45a3a354452..2a283b81f0fc 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -217,9 +217,53 @@ def tables(self) -> set[Table]: def limit(self) -> Optional[int]: return self._limit + def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]: + if "with" not in parsed: + return [] + return parsed["with"].get("cte_tables", []) + + def _check_cte_is_select(self, oxide_parse: list[dict[str, Any]]) -> bool: + """ + Check if a oxide parsed CTE contains only SELECT statements + + :param oxide_parse: parsed CTE + :return: True if CTE is a SELECT statement + """ + for query in oxide_parse: + parsed_query = query["Query"] + cte_tables = self._get_cte_tables(parsed_query) + for cte_table in cte_tables: + is_select = all( + key == "Select" for key in cte_table["query"]["body"].keys() + ) + if not is_select: + return False + return True + def is_select(self) -> bool: # make sure we strip comments; prevents a bug with comments in the CTE parsed = sqlparse.parse(self.strip_comments()) + + # Check if this is a CTE + if parsed[0].is_group and parsed[0][0].ttype == Keyword.CTE: + if sqloxide_parse is not None: + try: + if not self._check_cte_is_select( + sqloxide_parse(self.strip_comments(), dialect="ansi") + ): + return False + except ValueError: + # sqloxide was not able to parse the query, so let's continue with + # sqlparse + pass + inner_cte = self.get_inner_cte_expression(parsed[0].tokens) or [] + # Check if the inner CTE is a not a SELECT + if any(token.ttype == DDL for token in inner_cte) or any( + token.ttype == DML and token.normalized != "SELECT" + for token in inner_cte + ): + return False + if parsed[0].get_type() == "SELECT": return True @@ -241,6 +285,17 @@ def is_select(self) -> bool: token.ttype == DML and token.normalized == "SELECT" for token in parsed[0] ) + def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]: + for token in tokens: + if self._is_identifier(token): + for identifier_token in token.tokens: + if ( + isinstance(identifier_token, Parenthesis) + and identifier_token.is_group + ): + return identifier_token.tokens + return None + def is_valid_ctas(self) -> bool: parsed = sqlparse.parse(self.strip_comments()) return parsed[-1].get_type() == "SELECT" diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index e00dc3166e02..7d8839198c43 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1029,6 +1029,87 @@ def test_cte_is_select_lowercase() -> None: assert sql.is_select() +def test_cte_insert_is_not_select() -> None: + """ + Some CTEs with lowercase select are not correctly identified as SELECTS. + """ + sql = ParsedQuery( + """WITH foo AS( + INSERT INTO foo (id) VALUES (1) RETURNING 1 + ) select * FROM foo f""" + ) + assert sql.is_select() is False + + +def test_cte_delete_is_not_select() -> None: + """ + Some CTEs with lowercase select are not correctly identified as SELECTS. + """ + sql = ParsedQuery( + """WITH foo AS( + DELETE FROM foo RETURNING * + ) select * FROM foo f""" + ) + assert sql.is_select() is False + + +def test_cte_is_not_select_lowercase() -> None: + """ + Some CTEs with lowercase select are not correctly identified as SELECTS. + """ + sql = ParsedQuery( + """WITH foo AS( + insert into foo (id) values (1) RETURNING 1 + ) select * FROM foo f""" + ) + assert sql.is_select() is False + + +def test_cte_with_multiple_selects() -> None: + sql = ParsedQuery( + "WITH a AS ( select * from foo1 ), b as (select * from foo2) SELECT * FROM a;" + ) + assert sql.is_select() + + +def test_cte_with_multiple_with_non_select() -> None: + sql = ParsedQuery( + """WITH a AS ( + select * from foo1 + ), b as ( + update foo2 set id=2 + ) SELECT * FROM a""" + ) + assert sql.is_select() is False + sql = ParsedQuery( + """WITH a AS ( + update foo2 set name=2 + ), + b as ( + select * from foo1 + ) SELECT * FROM a""" + ) + assert sql.is_select() is False + sql = ParsedQuery( + """WITH a AS ( + update foo2 set name=2 + ), + b as ( + update foo1 set name=2 + ) SELECT * FROM a""" + ) + assert sql.is_select() is False + sql = ParsedQuery( + """WITH a AS ( + INSERT INTO foo (id) VALUES (1) + ), + b as ( + select 1 + ) SELECT * FROM a""" + ) + assert sql.is_select() is False + + def test_unknown_select() -> None: """ Test that `is_select` works when sqlparse fails to identify the type.