Skip to content

Commit

Permalink
fix(chart): save charts to save_charts_path (#710)
Browse files Browse the repository at this point in the history
* fix(chart): save charts to save_charts_path

* refactor sourcery changes

* 'Refactored by Sourcery'

* refactor chart save code

* fix: minor leftovers

---------

Co-authored-by: Sourcery AI <>
Co-authored-by: Gabriele Venturi <[email protected]>
  • Loading branch information
ArslanSaleem and gventuri authored Oct 31, 2023
1 parent f839482 commit 5a5155e
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 47 deletions.
30 changes: 15 additions & 15 deletions .sourcery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@

# This file was auto-generated by Sourcery on 2023-10-28 at 17:16.

version: '1' # The schema version of this config file
version: "1" # The schema version of this config file

ignore: # A list of paths or files which Sourcery will ignore.
- .git
- venv
- .venv
- env
- .env
- .tox
- node_modules
- vendor
- .git
- venv
- .venv
- env
- .env
- .tox
- node_modules
- vendor

rule_settings:
enable:
- default
disable: [] # A list of rule IDs Sourcery will never suggest.
- default
disable: ["no-conditionals-in-tests"] # A list of rule IDs Sourcery will never suggest.
rule_types:
- refactoring
- suggestion
- comment
python_version: '3.9' # A string specifying the lowest Python version your project supports. Sourcery will not suggest refactorings requiring a higher Python version.
- refactoring
- suggestion
- comment
python_version: "3.9" # A string specifying the lowest Python version your project supports. Sourcery will not suggest refactorings requiring a higher Python version.

# rules: # A list of custom rules Sourcery will include in its analysis.
# - id: no-print-statements
Expand Down
2 changes: 2 additions & 0 deletions pandasai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
While List Builtin Methods.
"""
# Default directory to store charts if the user doesn't provide one
DEFAULT_CHART_DIRECTORY = "exports/charts"

# List of Python builtin libraries that are added to the environment by default.
WHITELISTED_BUILTINS = [
Expand Down
3 changes: 2 additions & 1 deletion pandasai/schemas/df_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel, validator, Field
from typing import Optional, List, Any, Dict, Type, TypedDict
from pandasai.constants import DEFAULT_CHART_DIRECTORY

from pandasai.responses import ResponseParser
from ..middlewares.base import Middleware
Expand All @@ -25,7 +26,7 @@ class Config(BaseModel):
custom_instructions: Optional[str] = None
open_charts: bool = True
save_charts: bool = False
save_charts_path: str = "exports/charts"
save_charts_path: str = DEFAULT_CHART_DIRECTORY
custom_whitelisted_dependencies: List[str] = Field(default_factory=list)
max_retries: int = 3
middlewares: List[Middleware] = Field(default_factory=list)
Expand Down
17 changes: 13 additions & 4 deletions pandasai/smart_datalake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import logging
import os
import traceback
from pandasai.constants import DEFAULT_CHART_DIRECTORY
from pandasai.helpers.skills_manager import SkillsManager

from pandasai.skills import skill
Expand Down Expand Up @@ -151,10 +152,18 @@ def initialize(self):
"""

if self._config.save_charts:
try:
charts_dir = os.path.join((find_project_root()), "exports", "charts")
except ValueError:
charts_dir = os.path.join(os.getcwd(), "exports", "charts")
charts_dir = self._config.save_charts_path

# Add project root path if save_charts_path is default
if self._config.save_charts_path == DEFAULT_CHART_DIRECTORY:
try:
charts_dir = os.path.join(
(find_project_root()), self._config.save_charts_path
)
except ValueError:
charts_dir = os.path.join(
os.getcwd(), self._config.save_charts_path
)
os.makedirs(charts_dir, mode=0o777, exist_ok=True)

if self._config.enable_cache:
Expand Down
4 changes: 3 additions & 1 deletion tests/prompts/test_generate_python_code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ def test_advanced_reasoning_prompt(self):
prompt.set_config(dfs[0]._lake.config)
prompt.set_var("dfs", dfs)
prompt.set_var("conversation", "Question")
prompt.set_var("save_charts_path", "")
prompt.set_var("output_type_hint", "")
prompt.set_var("skills", "")
prompt.set_var("viz_library_type", "")
Expand Down Expand Up @@ -177,6 +176,9 @@ def analyze_data(dfs: list[pd.DataFrame]) -> dict:
- answer to the user as you would do as a data analyst; wrap it between <answer> tags; do not include the value or the chart itself (it will be calculated later).
- return the updated analyze_data function wrapped within ```python ```''' # noqa E501
actual_prompt_content = prompt.to_string()

print(expected_prompt_content)

if sys.platform.startswith("win"):
actual_prompt_content = actual_prompt_content.replace("\r\n", "\n")
assert actual_prompt_content == expected_prompt_content
Expand Down
18 changes: 9 additions & 9 deletions tests/test_smartdataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def test_updates_verbose_config_with_setters(self, smart_dataframe: SmartDatafra

smart_dataframe.verbose = True
assert smart_dataframe.verbose
assert smart_dataframe.lake._logger.verbose is True
assert smart_dataframe.lake._logger.verbose
assert len(smart_dataframe.lake._logger._logger.handlers) == 1
assert isinstance(
smart_dataframe.lake._logger._logger.handlers[0], logging.StreamHandler
Expand All @@ -615,16 +615,16 @@ def test_updates_verbose_config_with_setters(self, smart_dataframe: SmartDatafra
def test_updates_save_logs_config_with_setters(
self, smart_dataframe: SmartDataframe
):
assert smart_dataframe.save_logs is True
assert smart_dataframe.save_logs

smart_dataframe.save_logs = False
assert not smart_dataframe.save_logs
assert smart_dataframe.lake._logger.save_logs is False
assert not smart_dataframe.lake._logger.save_logs
assert len(smart_dataframe.lake._logger._logger.handlers) == 0

smart_dataframe.save_logs = True
assert smart_dataframe.save_logs
assert smart_dataframe.lake._logger.save_logs is True
assert smart_dataframe.lake._logger.save_logs
assert len(smart_dataframe.lake._logger._logger.handlers) == 1
assert isinstance(
smart_dataframe.lake._logger._logger.handlers[0], logging.FileHandler
Expand All @@ -637,7 +637,7 @@ def test_updates_enable_cache_config_with_setters(

smart_dataframe.enable_cache = True
assert smart_dataframe.enable_cache
assert smart_dataframe.lake.enable_cache is True
assert smart_dataframe.lake.enable_cache
assert smart_dataframe.lake.cache is not None
assert isinstance(smart_dataframe.lake._cache, Cache)

Expand All @@ -649,7 +649,7 @@ def test_updates_enable_cache_config_with_setters(
def test_updates_configs_with_setters(self, smart_dataframe: SmartDataframe):
assert smart_dataframe.callback is None
assert smart_dataframe.enforce_privacy is False
assert smart_dataframe.use_error_correction_framework is True
assert smart_dataframe.use_error_correction_framework
assert smart_dataframe.custom_prompts == {}
assert smart_dataframe.save_charts is False
assert smart_dataframe.save_charts_path == "exports/charts"
Expand Down Expand Up @@ -959,7 +959,7 @@ class TestSchema(BaseModel):

validation_result = df_object.validate(TestSchema)

assert validation_result.passed is True
assert validation_result.passed

def test_pydantic_validate_false(self, llm):
# Create a sample DataFrame
Expand Down Expand Up @@ -994,7 +994,7 @@ class TestSchema(BaseModel):
B: int

validation_result = df_object.validate(TestSchema)
assert validation_result.passed is True
assert validation_result.passed

def test_pydantic_validate_false_one_record(self, llm):
# Create a sample DataFrame
Expand Down Expand Up @@ -1039,7 +1039,7 @@ class TestSchema(BaseModel):

validation_result = df_object.validate(TestSchema)

assert validation_result.passed is True
assert validation_result.passed

def test_head_csv_with_sample_head(
self, sample_head, data_sampler, smart_dataframe: SmartDataframe
Expand Down
41 changes: 24 additions & 17 deletions tests/test_smartdatalake.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,28 +181,35 @@ def analyze_data(df):
""" # noqa: E501
)

@pytest.mark.parametrize(
"save_charts,enable_cache",
[(False, False), (False, True), (True, False), (True, True)],
)
@patch("os.makedirs")
def test_initialize(
self, mock_makedirs, smart_datalake: SmartDatalake, save_charts, enable_cache
):
smart_datalake.config.save_charts = save_charts
smart_datalake.config.enable_cache = enable_cache
def test_initialize_with_cache(self, mock_makedirs, smart_datalake):
# Modify the smart_datalake's configuration
smart_datalake.config.save_charts = True
smart_datalake.config.enable_cache = True

# Call the initialize method
smart_datalake.initialize()

if not save_charts and not enable_cache:
mock_makedirs.assert_not_called()
# Assertions for enabling cache
cache_dir = os.path.join(os.getcwd(), "cache")
mock_makedirs.assert_any_call(cache_dir, mode=0o777, exist_ok=True)

# Assertions for saving charts
charts_dir = os.path.join(os.getcwd(), smart_datalake.config.save_charts_path)
mock_makedirs.assert_any_call(charts_dir, mode=0o777, exist_ok=True)

if save_charts:
charts_dir = os.path.join(os.getcwd(), "exports", "charts")
mock_makedirs.assert_any_call(charts_dir, mode=0o777, exist_ok=True)
@patch("os.makedirs")
def test_initialize_without_cache(self, mock_makedirs, smart_datalake):
# Modify the smart_datalake's configuration
smart_datalake.config.save_charts = True
smart_datalake.config.enable_cache = False

# Call the initialize method
smart_datalake.initialize()

if enable_cache:
cache_dir = os.path.join(os.getcwd(), "cache")
mock_makedirs.assert_any_call(cache_dir, mode=0o777, exist_ok=True)
# Assertions for saving charts
charts_dir = os.path.join(os.getcwd(), smart_datalake.config.save_charts_path)
mock_makedirs.assert_called_once_with(charts_dir, mode=0o777, exist_ok=True)

def test_last_answer_and_reasoning(self, smart_datalake: SmartDatalake):
llm = FakeLLM(
Expand Down

0 comments on commit 5a5155e

Please sign in to comment.