Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql_query): validate if the query is not malicious #1568

Merged
merged 7 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pandasai/data_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

from pandasai.dataframe.base import DataFrame
from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
from pandasai.helpers.path import find_project_root
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name

from ..constants import (
LOCAL_SOURCE_TYPES,
Expand Down Expand Up @@ -197,6 +197,9 @@
load_function = self._get_loader_function(source_type)

try:
if not is_sql_query_safe(formatted_query):
raise MaliciousQueryError("Query is not safe to execute.")

Check warning on line 201 in pandasai/data_loader/loader.py

View check run for this annotation

Codecov / codecov/patch

pandasai/data_loader/loader.py#L200-L201

Added lines #L200 - L201 were not covered by tests
gventuri marked this conversation as resolved.
Show resolved Hide resolved

return load_function(connection_info, formatted_query, params)
except Exception as e:
raise RuntimeError(
Expand Down
70 changes: 70 additions & 0 deletions pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import re

import sqlglot


def sanitize_sql_table_name(filepath: str) -> str:
# Extract the file name without extension
Expand All @@ -14,3 +16,71 @@
sanitized_name = sanitized_name[:max_length]

return sanitized_name


def is_sql_query_safe(query: str) -> bool:
try:
# List of infected keywords to block (you can add more)
infected_keywords = [
r"\bINSERT\b",
r"\bUPDATE\b",
r"\bDELETE\b",
r"\bDROP\b",
r"\bEXEC\b",
r"\bALTER\b",
r"\bCREATE\b",
r"\bMERGE\b",
r"\bREPLACE\b",
r"\bTRUNCATE\b",
r"\bLOAD\b",
r"\bGRANT\b",
r"\bREVOKE\b",
r"\bCALL\b",
r"\bEXECUTE\b",
r"\bSHOW\b",
r"\bDESCRIBE\b",
r"\bEXPLAIN\b",
r"\bUSE\b",
r"\bSET\b",
r"\bDECLARE\b",
r"\bOPEN\b",
r"\bFETCH\b",
r"\bCLOSE\b",
r"\bSLEEP\b",
r"\bBENCHMARK\b",
r"\bDATABASE\b",
r"\bUSER\b",
r"\bCURRENT_USER\b",
r"\bSESSION_USER\b",
r"\bSYSTEM_USER\b",
r"\bVERSION\b",
r"\b@@VERSION\b",
r"--",
r"/\*.*\*/", # Block comments and inline comments
]
# Parse the query to extract its structure
parsed = sqlglot.parse_one(query)

# Ensure the main query is SELECT or WITH
if parsed.key == "SELECT":
gventuri marked this conversation as resolved.
Show resolved Hide resolved
return False

Check warning on line 66 in pandasai/helpers/sql_sanitizer.py

View check run for this annotation

Codecov / codecov/patch

pandasai/helpers/sql_sanitizer.py#L66

Added line #L66 was not covered by tests

# Check for infected keywords in the main query
if any(
re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
):
return False

# Check for infected keywords in subqueries
for subquery in parsed.find_all(sqlglot.exp.Subquery):
subquery_sql = subquery.sql() # Get the SQL of the subquery
if any(
re.search(keyword, subquery_sql, re.IGNORECASE)
for keyword in infected_keywords
):
return False

Check warning on line 81 in pandasai/helpers/sql_sanitizer.py

View check run for this annotation

Codecov / codecov/patch

pandasai/helpers/sql_sanitizer.py#L81

Added line #L81 was not covered by tests

return True

except sqlglot.errors.ParseError:
return False
70 changes: 69 additions & 1 deletion tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name


class TestSqlSanitizer:
Expand All @@ -17,3 +17,71 @@ def test_filename_with_long_name(self):
filepath = "/path/to/" + "a" * 100 + ".csv"
expected = "a" * 64
assert sanitize_sql_table_name(filepath) == expected

def test_safe_select_query(self):
query = "SELECT * FROM users WHERE username = 'admin';"
assert is_sql_query_safe(query)

def test_safe_with_query(self):
query = "WITH user_data AS (SELECT * FROM users) SELECT * FROM user_data;"
assert is_sql_query_safe(query)

def test_unsafe_insert_query(self):
query = "INSERT INTO users (username, password) VALUES ('admin', 'password');"
assert not is_sql_query_safe(query)

def test_unsafe_update_query(self):
query = "UPDATE users SET password = 'newpassword' WHERE username = 'admin';"
assert not is_sql_query_safe(query)

def test_unsafe_delete_query(self):
query = "DELETE FROM users WHERE username = 'admin';"
assert not is_sql_query_safe(query)

def test_unsafe_drop_query(self):
query = "DROP TABLE users;"
assert not is_sql_query_safe(query)

def test_unsafe_alter_query(self):
query = "ALTER TABLE users ADD COLUMN age INT;"
assert not is_sql_query_safe(query)

def test_unsafe_create_query(self):
query = "CREATE TABLE users (id INT, username VARCHAR(50));"
assert not is_sql_query_safe(query)

def test_safe_select_with_comment(self):
query = "SELECT * FROM users WHERE username = 'admin' -- comment"
assert not is_sql_query_safe(query) # Blocked by comment detection

def test_safe_select_with_inline_comment(self):
query = "SELECT * FROM users /* inline comment */ WHERE username = 'admin';"
assert not is_sql_query_safe(query) # Blocked by comment detection

def test_unsafe_query_with_subquery(self):
query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);"
assert is_sql_query_safe(query) # No dangerous keyword in main or subquery

def test_unsafe_query_with_subquery_insert(self):
query = (
"SELECT * FROM users WHERE id IN (INSERT INTO orders (user_id) VALUES (1));"
)
assert not is_sql_query_safe(query) # Subquery contains INSERT, blocked

def test_invalid_sql(self):
query = "INVALID SQL QUERY"
assert not is_sql_query_safe(query) # Invalid query should return False

def test_safe_query_with_multiple_keywords(self):
query = "SELECT name FROM users WHERE username = 'admin' AND age > 30;"
assert is_sql_query_safe(query) # Safe query with no dangerous keyword

def test_safe_query_with_subquery(self):
query = "SELECT name FROM users WHERE username IN (SELECT username FROM users WHERE age > 30);"
assert is_sql_query_safe(
query
) # Safe query with subquery, no dangerous keyword


if __name__ == "__main__":
unittest.main()
Loading