|
1 | 1 | import logging
|
2 |
| -from unittest.mock import MagicMock, mock_open, patch |
| 2 | +from unittest.mock import MagicMock, patch |
3 | 3 |
|
4 | 4 | import pandas as pd
|
5 | 5 | import pytest
|
6 | 6 |
|
7 | 7 | from pandasai import VirtualDataFrame
|
8 |
| -from pandasai.data_loader.loader import DatasetLoader |
9 |
| -from pandasai.data_loader.local_loader import LocalDatasetLoader |
10 |
| -from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema |
11 | 8 | from pandasai.data_loader.sql_loader import SQLDatasetLoader
|
12 | 9 | from pandasai.dataframe.base import DataFrame
|
13 |
| -from pandasai.exceptions import InvalidDataSourceType |
| 10 | +from pandasai.exceptions import MaliciousQueryError |
14 | 11 |
|
15 | 12 |
|
16 | 13 | class TestSqlDatasetLoader:
|
@@ -138,3 +135,62 @@ def test_load_with_transformation(self, mysql_schema):
|
138 | 135 | loader_function.call_args[0][1]
|
139 | 136 | == "SELECT email, first_name, timestamp FROM users ORDER BY RAND() LIMIT 5"
|
140 | 137 | )
|
| 138 | + |
def test_mysql_malicious_query(self, mysql_schema):
    """Verify that execute_query rejects a query flagged as unsafe.

    ``is_sql_query_safe`` is patched to return ``False``, so
    ``execute_query`` is expected to raise ``MaliciousQueryError``. The
    callable handed back by the patched ``_get_loader_function`` must not
    be invoked for the rejected query.
    """
    with patch(
        "pandasai.data_loader.sql_loader.is_sql_query_safe"
    ) as mock_sql_query, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
    ) as mock_loader_function:
        # Callable returned by the patched _get_loader_function.
        mocked_exec_function = MagicMock()
        mock_df = DataFrame(
            pd.DataFrame(
                {
                    "first_name": ["John"],
                    "timestamp": [pd.Timestamp.now()],
                }
            )
        )
        mocked_exec_function.return_value = mock_df
        mock_loader_function.return_value = mocked_exec_function
        loader = SQLDatasetLoader(mysql_schema, "test/users")
        # Force the safety check to report the query as unsafe.
        mock_sql_query.return_value = False
        logging.debug("Loading schema from dataset path: %s", loader)

        with pytest.raises(MaliciousQueryError):
            loader.execute_query("DROP TABLE users")

        mock_sql_query.assert_called_once_with("DROP TABLE users")
        # A rejected query must never be handed to the loader callable.
        mocked_exec_function.assert_not_called()
| 166 | + |
def test_mysql_safe_query(self, mysql_schema):
    """Verify that execute_query runs a query reported as safe.

    ``is_sql_query_safe`` is patched to return ``True`` and both
    ``_get_loader_function`` and ``_apply_transformations`` are patched, so
    the test checks only that the safety check receives the query, that the
    loader callable is invoked, and that a ``DataFrame`` comes back.
    """
    with patch(
        "pandasai.data_loader.sql_loader.is_sql_query_safe"
    ) as mock_sql_query, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._get_loader_function"
    ) as mock_loader_function, patch(
        "pandasai.data_loader.sql_loader.SQLDatasetLoader._apply_transformations"
    ) as mock_apply_transformations:
        # Callable returned by the patched _get_loader_function.
        mocked_exec_function = MagicMock()
        mock_df = DataFrame(
            pd.DataFrame(
                {
                    "first_name": ["John"],
                    "timestamp": [pd.Timestamp.now()],
                }
            )
        )
        mocked_exec_function.return_value = mock_df
        mock_apply_transformations.return_value = mock_df
        mock_loader_function.return_value = mocked_exec_function
        loader = SQLDatasetLoader(mysql_schema, "test/users")
        # Force the safety check to approve the query.
        mock_sql_query.return_value = True
        logging.debug("Loading schema from dataset path: %s", loader)

        result = loader.execute_query("select * from users")

        assert isinstance(result, DataFrame)
        mock_sql_query.assert_called_once_with("select * from users")
        # An approved query must actually reach the loader callable.
        mocked_exec_function.assert_called_once()
0 commit comments