diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index 9a7611fba..6740ec691 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -9,9 +9,9 @@
 
 from pandasai.dataframe.base import DataFrame
 from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
-from pandasai.exceptions import InvalidDataSourceType
+from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
 from pandasai.helpers.path import find_project_root
-from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
+from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name
 
 from ..constants import (
     LOCAL_SOURCE_TYPES,
@@ -197,6 +197,9 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
         load_function = self._get_loader_function(source_type)
 
+        if not is_sql_query_safe(formatted_query):
+            raise MaliciousQueryError("Query is not safe to execute.")
+
         try:
             return load_function(connection_info, formatted_query, params)
         except Exception as e:
             raise RuntimeError(
diff --git a/pandasai/helpers/sql_sanitizer.py b/pandasai/helpers/sql_sanitizer.py
index 82b4306eb..2086f745a 100644
--- a/pandasai/helpers/sql_sanitizer.py
+++ b/pandasai/helpers/sql_sanitizer.py
@@ -1,6 +1,8 @@
 import os
 import re
 
+import sqlglot
+
 
 def sanitize_sql_table_name(filepath: str) -> str:
     # Extract the file name without extension
@@ -14,3 +16,86 @@ def sanitize_sql_table_name(filepath: str) -> str:
         sanitized_name = sanitized_name[:max_length]
 
     return sanitized_name
+
+
+# List of infected keywords to block (you can add more)
+_INFECTED_KEYWORDS = (
+    r"\bINSERT\b",
+    r"\bUPDATE\b",
+    r"\bDELETE\b",
+    r"\bDROP\b",
+    r"\bEXEC\b",
+    r"\bALTER\b",
+    r"\bCREATE\b",
+    r"\bMERGE\b",
+    r"\bREPLACE\b",
+    r"\bTRUNCATE\b",
+    r"\bLOAD\b",
+    r"\bGRANT\b",
+    r"\bREVOKE\b",
+    r"\bCALL\b",
+    r"\bEXECUTE\b",
+    r"\bSHOW\b",
+    r"\bDESCRIBE\b",
+    r"\bEXPLAIN\b",
+    r"\bUSE\b",
+    r"\bSET\b",
+    r"\bDECLARE\b",
+    r"\bOPEN\b",
+    r"\bFETCH\b",
+    r"\bCLOSE\b",
+    r"\bSLEEP\b",
+    r"\bBENCHMARK\b",
+    r"\bDATABASE\b",
+    r"\bUSER\b",
+    r"\bCURRENT_USER\b",
+    r"\bSESSION_USER\b",
+    r"\bSYSTEM_USER\b",
+    r"\bVERSION\b",
+    r"@@VERSION\b",  # No leading \b: "@" is not a word character
+    r"--",  # Line comments
+    r"/\*[\s\S]*?\*/",  # Block comments, including ones spanning several lines
+)
+# Single case-insensitive alternation, compiled once at import time
+_INFECTED_PATTERN = re.compile("|".join(_INFECTED_KEYWORDS), re.IGNORECASE)
+
+# Root expression keys accepted as read-only. A WITH query parses to a Select
+# root, so "select" covers it; the other keys are set operations of SELECTs.
+_READ_ONLY_ROOT_KEYS = frozenset({"select", "union", "intersect", "except"})
+
+
+def is_sql_query_safe(query: str) -> bool:
+    """Check whether ``query`` is a read-only query free of blocked keywords.
+
+    A query is rejected when it cannot be tokenized or parsed, when its root
+    statement is not a SELECT (plain, with a WITH clause, or combined through
+    UNION / INTERSECT / EXCEPT), or when the query or the SQL regenerated from
+    any of its subqueries matches one of the blocked patterns.
+
+    Args:
+        query: The SQL query to inspect.
+
+    Returns:
+        True if the query passed every check, False otherwise.
+    """
+    try:
+        # Parse the query to extract its structure
+        parsed = sqlglot.parse_one(query)
+    except (sqlglot.errors.ParseError, sqlglot.errors.TokenError):
+        # A query that cannot be tokenized or parsed cannot be vetted
+        return False
+
+    # Ensure the main query is a SELECT; expression keys are compared
+    # lower-cased because sqlglot derives them from the lower-cased class name
+    if parsed.key.lower() not in _READ_ONLY_ROOT_KEYS:
+        return False
+
+    # Check for infected keywords in the main query
+    if _INFECTED_PATTERN.search(query):
+        return False
+
+    # Check for infected keywords in subqueries
+    return not any(
+        _INFECTED_PATTERN.search(subquery.sql())
+        for subquery in parsed.find_all(sqlglot.exp.Subquery)
+    )
diff --git a/tests/unit_tests/helpers/test_sql_sanitizer.py b/tests/unit_tests/helpers/test_sql_sanitizer.py
index 5f4ab40fc..a572cc5f4 100644
--- a/tests/unit_tests/helpers/test_sql_sanitizer.py
+++ b/tests/unit_tests/helpers/test_sql_sanitizer.py
@@ -1,4 +1,4 @@
-from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
+from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name
 
 
 class TestSqlSanitizer:
@@ -17,3 +17,86 @@ def test_filename_with_long_name(self):
         filepath = "/path/to/" + "a" * 100 + ".csv"
         expected = "a" * 64
         assert sanitize_sql_table_name(filepath) == expected
+
+    def test_safe_select_query(self):
+        query = "SELECT * FROM users WHERE username = 'admin';"
+        assert is_sql_query_safe(query)
+
+    def test_safe_with_query(self):
+        query = "WITH user_data AS (SELECT * FROM users) SELECT * FROM user_data;"
+        assert is_sql_query_safe(query)
+
+    def test_safe_union_query(self):
+        query = "SELECT name FROM users UNION SELECT name FROM admins;"
+        assert is_sql_query_safe(query)  # Set operation of plain SELECTs
+
+    def test_unsafe_insert_query(self):
+        query = "INSERT INTO users (username, password) VALUES ('admin', 'password');"
+        assert not is_sql_query_safe(query)
+
+    def test_unsafe_update_query(self):
+        query = "UPDATE users SET password = 'newpassword' WHERE username = 'admin';"
+        assert not is_sql_query_safe(query)
+
+    def test_unsafe_delete_query(self):
+        query = "DELETE FROM users WHERE username = 'admin';"
+        assert not is_sql_query_safe(query)
+
+    def test_unsafe_drop_query(self):
+        query = "DROP TABLE users;"
+        assert not is_sql_query_safe(query)
+
+    def test_unsafe_alter_query(self):
+        query = "ALTER TABLE users ADD COLUMN age INT;"
+        assert not is_sql_query_safe(query)
+
+    def test_unsafe_create_query(self):
+        query = "CREATE TABLE users (id INT, username VARCHAR(50));"
+        assert not is_sql_query_safe(query)
+
+    def test_unsafe_non_select_statement(self):
+        query = "COMMIT;"
+        assert not is_sql_query_safe(query)  # No blocked keyword, but not a SELECT
+
+    def test_select_with_line_comment_is_blocked(self):
+        query = "SELECT * FROM users WHERE username = 'admin' -- comment"
+        assert not is_sql_query_safe(query)  # Blocked by comment detection
+
+    def test_select_with_block_comment_is_blocked(self):
+        query = "SELECT * FROM users /* inline comment */ WHERE username = 'admin';"
+        assert not is_sql_query_safe(query)  # Blocked by comment detection
+
+    def test_select_with_multiline_block_comment_is_blocked(self):
+        query = "SELECT * FROM users /* multi\nline */ WHERE username = 'admin';"
+        assert not is_sql_query_safe(query)  # Comment spans a newline
+
+    def test_safe_select_with_in_subquery(self):
+        query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);"
+        assert is_sql_query_safe(query)  # No dangerous keyword in main or subquery
+
+    def test_unsafe_query_with_subquery_insert(self):
+        query = (
+            "SELECT * FROM users WHERE id IN (INSERT INTO orders (user_id) VALUES (1));"
+        )
+        assert not is_sql_query_safe(query)  # Subquery contains INSERT, blocked
+
+    def test_invalid_sql(self):
+        query = "INVALID SQL QUERY"
+        assert not is_sql_query_safe(query)  # Invalid query should return False
+
+    def test_safe_query_with_multiple_keywords(self):
+        query = "SELECT name FROM users WHERE username = 'admin' AND age > 30;"
+        assert is_sql_query_safe(query)  # Safe query with no dangerous keyword
+
+    def test_safe_query_with_subquery(self):
+        query = "SELECT name FROM users WHERE username IN (SELECT username FROM users WHERE age > 30);"
+        assert is_sql_query_safe(
+            query
+        )  # Safe query with subquery, no dangerous keyword
+
+
+if __name__ == "__main__":
+    # unittest is not imported at module level; import it where it is used
+    import unittest
+
+    unittest.main()