Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sql_query): validate if the query is not malicious #1568

Merged
merged 7 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion pandasai/data_loader/sql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import pandas as pd

from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.exceptions import InvalidDataSourceType, MaliciousQueryError
from pandasai.helpers.sql_sanitizer import is_sql_query_safe

from ..constants import (
SUPPORTED_SOURCE_CONNECTORS,
Expand Down Expand Up @@ -36,6 +37,12 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra

formatted_query = self.query_builder.format_query(query)
load_function = self._get_loader_function(source_type)

if not is_sql_query_safe(formatted_query):
raise MaliciousQueryError(
"The SQL query is deemed unsafe and will not be executed."
)

try:
dataframe: pd.DataFrame = load_function(
connection_info, formatted_query, params
Expand Down
70 changes: 70 additions & 0 deletions pandasai/helpers/sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import re

import sqlglot


def sanitize_sql_table_name(filepath: str) -> str:
# Extract the file name without extension
Expand All @@ -14,3 +16,71 @@
sanitized_name = sanitized_name[:max_length]

return sanitized_name


def is_sql_query_safe(query: str) -> bool:
try:
# List of infected keywords to block (you can add more)
infected_keywords = [
r"\bINSERT\b",
r"\bUPDATE\b",
r"\bDELETE\b",
r"\bDROP\b",
r"\bEXEC\b",
r"\bALTER\b",
r"\bCREATE\b",
r"\bMERGE\b",
r"\bREPLACE\b",
r"\bTRUNCATE\b",
r"\bLOAD\b",
r"\bGRANT\b",
r"\bREVOKE\b",
r"\bCALL\b",
r"\bEXECUTE\b",
r"\bSHOW\b",
r"\bDESCRIBE\b",
r"\bEXPLAIN\b",
r"\bUSE\b",
r"\bSET\b",
r"\bDECLARE\b",
r"\bOPEN\b",
r"\bFETCH\b",
r"\bCLOSE\b",
r"\bSLEEP\b",
r"\bBENCHMARK\b",
r"\bDATABASE\b",
r"\bUSER\b",
r"\bCURRENT_USER\b",
r"\bSESSION_USER\b",
r"\bSYSTEM_USER\b",
r"\bVERSION\b",
r"\b@@VERSION\b",
r"--",
r"/\*.*\*/", # Block comments and inline comments
]
# Parse the query to extract its structure
parsed = sqlglot.parse_one(query)

# Ensure the main query is SELECT
if parsed.key.upper() != "SELECT":
return False

# Check for infected keywords in the main query
if any(
re.search(keyword, query, re.IGNORECASE) for keyword in infected_keywords
):
return False

# Check for infected keywords in subqueries
for subquery in parsed.find_all(sqlglot.exp.Subquery):
subquery_sql = subquery.sql() # Get the SQL of the subquery
if any(
re.search(keyword, subquery_sql, re.IGNORECASE)
for keyword in infected_keywords
):
return False

Check warning on line 81 in pandasai/helpers/sql_sanitizer.py

View check run for this annotation

Codecov / codecov/patch

pandasai/helpers/sql_sanitizer.py#L81

Added line #L81 was not covered by tests

return True

except sqlglot.errors.ParseError:
return False
1 change: 0 additions & 1 deletion tests/unit_tests/data_loader/test_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from unittest.mock import mock_open, patch

import pandas as pd
Expand Down
66 changes: 61 additions & 5 deletions tests/unit_tests/data_loader/test_sql_loader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
import logging
from unittest.mock import MagicMock, mock_open, patch
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from pandasai import VirtualDataFrame
from pandasai.data_loader.loader import DatasetLoader
from pandasai.data_loader.local_loader import LocalDatasetLoader
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
from pandasai.data_loader.sql_loader import SQLDatasetLoader
from pandasai.dataframe.base import DataFrame
from pandasai.exceptions import InvalidDataSourceType
from pandasai.exceptions import MaliciousQueryError


class TestSqlDatasetLoader:
Expand Down Expand Up @@ -138,3 +135,62 @@ def test_load_with_transformation(self, mysql_schema):
loader_function.call_args[0][1]
== "SELECT email, first_name, timestamp FROM users ORDER BY RAND() LIMIT 5"
)

def test_mysql_malicious_query(self, mysql_schema):
"""Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
with patch(
"pandasai.data_loader.sql_loader.is_sql_query_safe"
) as mock_sql_query, patch(
"pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
) as mock_loader_function:
mocked_exec_function = MagicMock()
mock_df = DataFrame(
pd.DataFrame(
{
"email": ["[email protected]"],
"first_name": ["John"],
"timestamp": [pd.Timestamp.now()],
}
)
)
mocked_exec_function.return_value = mock_df
mock_loader_function.return_value = mocked_exec_function
loader = SQLDatasetLoader(mysql_schema, "test/users")
mock_sql_query.return_value = False
logging.debug("Loading schema from dataset path: %s", loader)

with pytest.raises(MaliciousQueryError):
loader.execute_query("DROP TABLE users")

mock_sql_query.assert_called_once_with("DROP TABLE users")

def test_mysql_safe_query(self, mysql_schema):
"""Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
with patch(
"pandasai.data_loader.sql_loader.is_sql_query_safe"
) as mock_sql_query, patch(
"pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
) as mock_loader_function, patch(
"pandasai.data_loader.sql_loader.SQLDatasetLoader._apply_transformations"
) as mock_apply_transformations:
mocked_exec_function = MagicMock()
mock_df = DataFrame(
pd.DataFrame(
{
"email": ["[email protected]"],
"first_name": ["John"],
"timestamp": [pd.Timestamp.now()],
}
)
)
mocked_exec_function.return_value = mock_df
mock_apply_transformations.return_value = mock_df
mock_loader_function.return_value = mocked_exec_function
loader = SQLDatasetLoader(mysql_schema, "test/users")
mock_sql_query.return_value = True
logging.debug("Loading schema from dataset path: %s", loader)

result = loader.execute_query("select * from users")

assert isinstance(result, DataFrame)
mock_sql_query.assert_called_once_with("select * from users")
70 changes: 69 additions & 1 deletion tests/unit_tests/helpers/test_sql_sanitizer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
from pandasai.helpers.sql_sanitizer import is_sql_query_safe, sanitize_sql_table_name


class TestSqlSanitizer:
Expand All @@ -17,3 +17,71 @@ def test_filename_with_long_name(self):
filepath = "/path/to/" + "a" * 100 + ".csv"
expected = "a" * 64
assert sanitize_sql_table_name(filepath) == expected

def test_safe_select_query(self):
query = "SELECT * FROM users WHERE username = 'admin';"
assert is_sql_query_safe(query)

def test_safe_with_query(self):
query = "WITH user_data AS (SELECT * FROM users) SELECT * FROM user_data;"
assert is_sql_query_safe(query)

def test_unsafe_insert_query(self):
query = "INSERT INTO users (username, password) VALUES ('admin', 'password');"
assert not is_sql_query_safe(query)

def test_unsafe_update_query(self):
query = "UPDATE users SET password = 'newpassword' WHERE username = 'admin';"
assert not is_sql_query_safe(query)

def test_unsafe_delete_query(self):
query = "DELETE FROM users WHERE username = 'admin';"
assert not is_sql_query_safe(query)

def test_unsafe_drop_query(self):
query = "DROP TABLE users;"
assert not is_sql_query_safe(query)

def test_unsafe_alter_query(self):
query = "ALTER TABLE users ADD COLUMN age INT;"
assert not is_sql_query_safe(query)

def test_unsafe_create_query(self):
query = "CREATE TABLE users (id INT, username VARCHAR(50));"
assert not is_sql_query_safe(query)

def test_safe_select_with_comment(self):
query = "SELECT * FROM users WHERE username = 'admin' -- comment"
assert not is_sql_query_safe(query) # Blocked by comment detection

def test_safe_select_with_inline_comment(self):
query = "SELECT * FROM users /* inline comment */ WHERE username = 'admin';"
assert not is_sql_query_safe(query) # Blocked by comment detection

def test_unsafe_query_with_subquery(self):
query = "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders);"
assert is_sql_query_safe(query) # No dangerous keyword in main or subquery

def test_unsafe_query_with_subquery_insert(self):
query = (
"SELECT * FROM users WHERE id IN (INSERT INTO orders (user_id) VALUES (1));"
)
assert not is_sql_query_safe(query) # Subquery contains INSERT, blocked

def test_invalid_sql(self):
query = "INVALID SQL QUERY"
assert not is_sql_query_safe(query) # Invalid query should return False

def test_safe_query_with_multiple_keywords(self):
query = "SELECT name FROM users WHERE username = 'admin' AND age > 30;"
assert is_sql_query_safe(query) # Safe query with no dangerous keyword

def test_safe_query_with_subquery(self):
query = "SELECT name FROM users WHERE username IN (SELECT username FROM users WHERE age > 30);"
assert is_sql_query_safe(
query
) # Safe query with subquery, no dangerous keyword


if __name__ == "__main__":
unittest.main()
Loading