Skip to content

Commit

Permalink
fix(sql_query): validate if the query is not malicious
Browse files Browse the repository at this point in the history
  • Loading branch information
ArslanSaleem committed Jan 30, 2025
1 parent 01bf53e commit 8d50ee3
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 3 deletions.
7 changes: 5 additions & 2 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from pandasai.dataframe.base import DataFrame
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
from pandasai.helpers.path import find_project_root
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name

from ..constants import (
LOCAL_SOURCE_TYPES,
Expand Down Expand Up @@ -197,6 +197,9 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
load_function = self._get_loader_function(source_type)

try:
if not is_sql_query_safe(formatted_query):
raise MaliciousQueryError("Query is not safe to execute.")

Check warning on line 201 in pandasai/data_loader/loader.py

View check run for this annotation

Codecov / codecov/patch

pandasai/data_loader/loader.py#L200-L201

Added lines #L200 - L201 were not covered by tests

return load_function(connection_info, formatted_query, params)
except Exception as e:
raise RuntimeError(
Expand Down
70 changes: 70 additions & 0 deletions pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import re

import sqlglot


def sanitize_sql_table_name(filepath: str) -> str:
# Extract the file name without extension
Expand All @@ -14,3 +16,71 @@ def sanitize_sql_table_name(filepath: str) -> str:
sanitized_name = sanitized_name[:max_length]

return sanitized_name


def is_sql_query_safe(query: str) -> bool:
try:
# List of infected keywords to block (you can add more)
infected_keywords = [
r"\bINSERT\b",
r"\bUPDATE\b",
r"\bDELETE\b",
r"\bDROP\b",
r"\bEXEC\b",
r"\bALTER\b",
r"\bCREATE\b",
r"\bMERGE\b",
r"\bREPLACE\b",
r"\bTRUNCATE\b",
r"\bLOAD\b",
r"\bGRANT\b",
r"\bREVOKE\b",
r"\bCALL\b",
r"\bEXECUTE\b",
r"\bSHOW\b",
r"\bDESCRIBE\b",
r"\bEXPLAIN\b",
r"\bUSE\b",
r"\bSET\b",
r"\bDECLARE\b",
r"\bOPEN\b",
r"\bFETCH\b",
r"\bCLOSE\b",
r"\bSLEEP\b",
r"\bBENCHMARK\b",
r"\bDATABASE\b",
r"\bUSER\b",
r"\bCURRENT_USER\b",
r"\bSESSION_USER\b",
r"\bSYSTEM_USER\b",
r"\bVERSION\b",
r"\b@@VERSION\b",
r"--",
r"/\*.*\*/", # Block comments and inline comments
]
# Parse the query to extract its structure
parsed = sqlglot.parse_one(query)

# Ensure the main query is SELECT or WITH
if parsed.key == "SELECT":
return False

Check warning on line 66 in pandasai/helpers/sql_sanitizer.py

View check run for this annotation

Codecov / codecov/patch

pandasai/helpers/sql_sanitizer.py#L66

Added line #L66 was not covered by tests

# Check for infected keywords in the main query
if any(
re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
):
return False

# Check for infected keywords in subqueries
for subquery in parsed.find_all(sqlglot.exp.Subquery):
subquery_sql = subquery.sql() # Get the SQL of the subquery
if any(
re.search(keyword, subquery_sql, re.IGNORECASE)
for keyword in infected_keywords
):
return False

Check warning on line 81 in pandasai/helpers/sql_sanitizer.py

View check run for this annotation

Codecov / codecov/patch

pandasai/helpers/sql_sanitizer.py#L81

Added line #L81 was not covered by tests

return True

except sqlglot.errors.ParseError:
return False
70 changes: 69 additions & 1 deletion tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name


class TestSqlSanitizer:
Expand All @@ -17,3 +17,71 @@ def test_filename_with_long_name(self):
filepath = "/path/to/" + "a" * 100 + ".csv"
expected = "a" * 64
assert sanitize_sql_table_name(filepath) == expected

def test_safe_select_query(self):
query = "SELECT * FROM users WHERE username = 'admin';"
assert is_sql_query_safe(query)

def test_safe_with_query(self):
query = "WITH user_data AS (SELECT * FROM users) SELECT * FROM user_data;"
assert is_sql_query_safe(query)

def test_unsafe_insert_query(self):
query = "INSERT INTO users (username, password) VALUES ('admin', 'password');"
assert not is_sql_query_safe(query)

def test_unsafe_update_query(self):
query = "UPDATE users SET password = 'newpassword' WHERE username = 'admin';"
assert not is_sql_query_safe(query)

def test_unsafe_delete_query(self):
query = "DELETE FROM users WHERE username = 'admin';"
assert not is_sql_query_safe(query)

def test_unsafe_drop_query(self):
query = "DROP TABLE users;"
assert not is_sql_query_safe(query)

def test_unsafe_alter_query(self):
query = "ALTER TABLE users ADD COLUMN age INT;"
assert not is_sql_query_safe(query)

def test_unsafe_create_query(self):
query = "CREATE TABLE users (id INT, username VARCHAR(50));"
assert not is_sql_query_safe(query)

def test_safe_select_with_comment(self):
query = "SELECT * FROM users WHERE username = 'admin' -- comment"
assert not is_sql_query_safe(query) # Blocked by comment detection

def test_safe_select_with_inline_comment(self):
query = "SELECT * FROM users /* inline comment */ WHERE username = 'admin';"
assert not is_sql_query_safe(query) # Blocked by comment detection

def test_unsafe_query_with_subquery(self):
query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);"
assert is_sql_query_safe(query) # No dangerous keyword in main or subquery

def test_unsafe_query_with_subquery_insert(self):
query = (
"SELECT * FROM users WHERE id IN (INSERT INTO orders (user_id) VALUES (1));"
)
assert not is_sql_query_safe(query) # Subquery contains INSERT, blocked

def test_invalid_sql(self):
query = "INVALID SQL QUERY"
assert not is_sql_query_safe(query) # Invalid query should return False

def test_safe_query_with_multiple_keywords(self):
query = "SELECT name FROM users WHERE username = 'admin' AND age > 30;"
assert is_sql_query_safe(query) # Safe query with no dangerous keyword

def test_safe_query_with_subquery(self):
query = "SELECT name FROM users WHERE username IN (SELECT username FROM users WHERE age > 30);"
assert is_sql_query_safe(
query
) # Safe query with subquery, no dangerous keyword


if __name__ == "__main__":
unittest.main()

0 comments on commit 8d50ee3

Please sign in to comment.