Skip to content

Commit ce303ee

Browse files
authored
Merge pull request sinaptik-ai#16 from yzaparto/zaparto/error-correcting-framework
Error Correcting Framework
2 parents ec682a7 + 68ffeb5 commit ce303ee

File tree

7 files changed

+96
-69
lines changed

7 files changed

+96
-69
lines changed

.pylintrc

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
[MASTER]
12
ignore=test_*

examples/data/sample_dataframe.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""Sample data for dataframe examples."""
2+
3+
dataframe = {
4+
"country": [
5+
"United States",
6+
"United Kingdom",
7+
"France",
8+
"Germany",
9+
"Italy",
10+
"Spain",
11+
"Canada",
12+
"Australia",
13+
"Japan",
14+
"China",
15+
],
16+
"gdp": [
17+
21400000,
18+
2940000,
19+
2830000,
20+
3870000,
21+
2160000,
22+
1350000,
23+
1780000,
24+
1320000,
25+
516000,
26+
14000000,
27+
],
28+
"happiness_index": [7.3, 7.2, 6.5, 7.0, 6.0, 6.3, 7.3, 7.3, 5.9, 5.0],
29+
}

examples/from_dataframe.py

+2-29
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,9 @@
33
import pandas as pd
44
from pandasai import PandasAI
55
from pandasai.llm.openai import OpenAI
6+
from .data.sample_dataframe import dataframe
67

7-
df = pd.DataFrame(
8-
{
9-
"country": [
10-
"United States",
11-
"United Kingdom",
12-
"France",
13-
"Germany",
14-
"Italy",
15-
"Spain",
16-
"Canada",
17-
"Australia",
18-
"Japan",
19-
"China",
20-
],
21-
"gdp": [
22-
21400000,
23-
2940000,
24-
2830000,
25-
3870000,
26-
2160000,
27-
1350000,
28-
1780000,
29-
1320000,
30-
516000,
31-
14000000,
32-
],
33-
"happiness_index": [7.3, 7.2, 6.5, 7.0, 6.0, 6.3, 7.3, 7.3, 5.9, 5.0],
34-
}
35-
)
8+
df = pd.DataFrame(dataframe)
369

3710
llm = OpenAI()
3811
pandas_ai = PandasAI(llm, verbose=True, conversational=False)

examples/with_privacy_enforced.py

+3-32
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,15 @@
33
import pandas as pd
44
from pandasai import PandasAI
55
from pandasai.llm.openai import OpenAI
6+
from .data.sample_dataframe import dataframe
67

7-
df = pd.DataFrame(
8-
{
9-
"country": [
10-
"United States",
11-
"United Kingdom",
12-
"France",
13-
"Germany",
14-
"Italy",
15-
"Spain",
16-
"Canada",
17-
"Australia",
18-
"Japan",
19-
"China",
20-
],
21-
"gdp": [
22-
21400000,
23-
2940000,
24-
2830000,
25-
3870000,
26-
2160000,
27-
1350000,
28-
1780000,
29-
1320000,
30-
516000,
31-
14000000,
32-
],
33-
"happiness_index": [7.3, 7.2, 6.5, 7.0, 6.0, 6.3, 7.3, 7.3, 5.9, 5.0],
34-
}
35-
)
8+
df = pd.DataFrame(dataframe)
369

3710
llm = OpenAI()
38-
pandas_ai = PandasAI(llm, verbose=True, conversational=False)
11+
pandas_ai = PandasAI(llm, verbose=True, conversational=False, enforce_privacy=True)
3912
response = pandas_ai.run(
4013
df,
4114
"Calculate the sum of the gdp of north american countries",
42-
enforce_privacy=True,
43-
is_conversational_answer=True,
4415
)
4516
print(response)
4617
# Output: 26200000

pandasai/__init__.py

+54-5
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,41 @@ class PandasAI:
1515
This is the result of `print(df.head({rows_to_display}))`:
1616
{df_head}.
1717
18-
Return the python code (do not import anything) to get the answer to the following question:
18+
Return the python code (do not import anything) and make sure to prefix the python code with <startCode> exactly and suffix the code with <endCode> exactly
19+
to get the answer to the following question :
1920
"""
2021
_response_instruction: str = """
2122
Question: {question}
2223
Answer: {answer}
2324
2425
Rewrite the answer to the question in a conversational way.
2526
"""
27+
28+
_error_correct_instruction: str = """
29+
For the task defined below:
30+
{orig_task}
31+
you generated this python code:
32+
{code}
33+
and this fails with the following error:
34+
{error_returned}
35+
Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error.
36+
Make sure to prefix the python code with <startCode> exactly and suffix the code with <endCode> exactly.
37+
"""
2638
_llm: LLM
2739
_verbose: bool = False
2840
_is_conversational_answer: bool = True
2941
_enforce_privacy: bool = False
42+
_max_retries: int = 3
43+
_original_instruction_and_prompt = None
3044
last_code_generated: str = None
3145
code_output: str = None
3246

3347
def __init__(
34-
self, llm=None, conversational=True, verbose=False, enforce_privacy=False
48+
self,
49+
llm=None,
50+
conversational=True,
51+
verbose=False,
52+
enforce_privacy=False,
3553
):
3654
if llm is None:
3755
raise LLMNotFoundError(
@@ -74,6 +92,13 @@ def run(
7492
),
7593
prompt,
7694
)
95+
self._original_instruction_and_prompt = (
96+
self._task_instruction.format(
97+
df_head=data_frame.head(rows_to_display),
98+
rows_to_display=rows_to_display,
99+
)
100+
+ prompt
101+
)
77102
self.last_code_generated = code
78103
self.log(
79104
f"""
@@ -83,7 +108,7 @@ def run(
83108
```"""
84109
)
85110

86-
answer = self.run_code(code, data_frame)
111+
answer = self.run_code(code, data_frame, False)
87112
self.code_output = answer
88113
self.log(f"Answer: {answer}")
89114

@@ -95,7 +120,10 @@ def run(
95120
return answer
96121

97122
def run_code(
98-
self, code: str, df: pd.DataFrame # pylint: disable=W0613 disable=C0103
123+
self,
124+
code: str,
125+
df: pd.DataFrame, # pylint: disable=W0613 disable=C0103
126+
use_error_correction_framework: bool = False,
99127
) -> str:
100128
# pylint: disable=W0122 disable=W0123 disable=W0702:bare-except
101129
"""Run the code in the current context and return the result"""
@@ -105,7 +133,28 @@ def run_code(
105133
sys.stdout = output
106134

107135
# Execute the code
108-
exec(code)
136+
if use_error_correction_framework:
137+
count = 0
138+
code_to_run = code
139+
while count < self._max_retries:
140+
try:
141+
exec(code_to_run)
142+
code = code_to_run
143+
break
144+
except Exception as e: # pylint: disable=W0718 disable=C0103
145+
count += 1
146+
error_correcting_instruction = (
147+
self._error_correct_instruction.format(
148+
orig_task=self._original_instruction_and_prompt,
149+
code=code,
150+
error_returned=e,
151+
)
152+
)
153+
code_to_run = self._llm.generate_code(
154+
error_correcting_instruction, ""
155+
)
156+
else:
157+
exec(code)
109158

110159
# Restore standard output and get the captured output
111160
sys.stdout = sys.__stdout__

pandasai/llm/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def _extract_code(self, response: str, separator: str = "```") -> str:
5959
code = response
6060
if len(response.split(separator)) > 1:
6161
code = response.split(separator)[1]
62+
if re.match(r"<startCode>([\s\S]*?)<\/?endCode>", code):
63+
code = re.findall(r"<startCode>([\s\S]*?)<\/?endCode>", code)[0]
6264
code = self._polish_code(code)
6365
if not self._is_python_code(code):
6466
raise NoCodeFoundError("No code found in the response")

tests/test_pandasai.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def test_run_with_privacy_enforcement(self):
139139
Columns: [country]
140140
Index: [].
141141
142-
Return the python code (do not import anything) to get the answer to the following question:
142+
Return the python code (do not import anything) and make sure to prefix the python code with <startCode> exactly and suffix the code with <endCode> exactly
143+
to get the answer to the following question :
143144
How many countries are in the dataframe?"""
144145
self.pandasai.run(df, "How many countries are in the dataframe?")
145146
assert self.pandasai._llm.last_prompt == expected_prompt
@@ -159,7 +160,8 @@ def test_run_without_privacy_enforcement(self):
159160
1 United Kingdom
160161
2 France.
161162
162-
Return the python code (do not import anything) to get the answer to the following question:
163+
Return the python code (do not import anything) and make sure to prefix the python code with <startCode> exactly and suffix the code with <endCode> exactly
164+
to get the answer to the following question :
163165
How many countries are in the dataframe?"""
164166
self.pandasai.run(df, "How many countries are in the dataframe?")
165-
assert self.pandasai._llm.last_prompt == expected_prompt
167+
assert self.pandasai._llm.last_prompt == expected_prompt

0 commit comments

Comments
 (0)