
Commit 3cf1f7d

Support models used by text2sql project [Initial commit] (#1)
1 parent d4c1d82 commit 3cf1f7d

5 files changed: +105 -23 lines changed


pyspark_ai/ai_utils.py

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ def __init__(self, spark_ai, df_instance: DataFrame):
         self.spark_ai = spark_ai
         self.df_instance = df_instance
 
+    def transform_tpch(self, desc: str, table: str, cache: bool = False) -> DataFrame:
+        return self.spark_ai.transform_df_tpch(desc, table, cache)
+
     def transform(self, desc: str, cache: bool = True) -> DataFrame:
         """
         Transform the DataFrame using the given description.
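
A minimal usage sketch of the new wrapper, assuming the `.ai` DataFrame accessor that `spark_ai.activate()` installs in upstream pyspark-ai; the TPC-H table name and question below are illustrative, not part of this commit:

    # Hedged sketch: `lineitem` and the description are made-up examples.
    from pyspark.sql import SparkSession
    from pyspark_ai import SparkAI

    spark = SparkSession.builder.getOrCreate()
    spark_ai = SparkAI()
    spark_ai.activate()  # attaches the `.ai` accessor to DataFrames

    lineitem = spark.table("lineitem")  # hypothetical TPC-H table
    # The DataFrame only supplies the accessor; the query is generated against
    # the named table via spark_ai.transform_df_tpch.
    result = lineitem.ai.transform_tpch(
        desc="total quantity shipped per ship year", table="lineitem"
    )
    result.show()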

pyspark_ai/prompt.py

Lines changed: 13 additions & 11 deletions
@@ -171,23 +171,25 @@
     spark_sql_shared_example_4,
 ]
 
-SPARK_SQL_SUFFIX = """\nQuestion: Given a Spark temp view `{view_name}` {comment}.
-
-Here are column names and sample values from each column, to help you understand the columns in the dataframe.
-The format will be (column_name, type, [sample_value_1, sample_value_2...])...
-Use these column names and sample values to help you choose which columns to query.
-It's very important to ONLY use the verbatim column_name in your resulting SQL query; DO NOT include the type.
+SPARK_SQL_SUFFIX = """\nQuestion: Given a Spark temp view `{view_name}` {comment} with the following sample vals,
+in the format (column_name, type, [sample_value_1, sample_value_2...]):
+```
 {sample_vals}
-
-Write a Spark SQL query to retrieve the following from view `{view_name}`: {desc}
+```
+Write a Spark SQL query to retrieve from view `{view_name}`: {desc}
+Answer:
 """
 
 SPARK_SQL_SUFFIX_FOR_AGENT = SPARK_SQL_SUFFIX + "\n{agent_scratchpad}"
 
 SPARK_SQL_PREFIX = """You are an assistant for writing professional Spark SQL queries.
-Given a question, you need to write a Spark SQL query to answer the question. The result is ALWAYS a Spark SQL query.
-Use the COUNT SQL function when the query asks for total number of some non-countable column.
-Use the SUM SQL function to accumulate the total number of countable column values."""
+Given a question, you need to write a Spark SQL query to answer the question.
+The rules that you should follow for answering question:
+1.The answer only consists of Spark SQL query. No explaination. No
+2.SQL statements should be Spark SQL query.
+3.ONLY use the verbatim column_name in your resulting SQL query; DO NOT include the type.
+4.Use the COUNT SQL function when the query asks for total number of some non-countable column.
+5.Use the SUM SQL function to accumulate the total number of countable column values."""
 
 SPARK_SQL_PREFIX_VECTOR_SEARCH = (
     SPARK_SQL_PREFIX
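
Since SPARK_SQL_SUFFIX stays a plain format string, one way to preview the prompt the model now sees is to substitute illustrative values directly (view name, sample values, and question below are made up; the real chain composes it through its prompt template):

    # Hedged illustration: renders the reshaped suffix with made-up values.
    from pyspark_ai.prompt import SPARK_SQL_SUFFIX

    rendered = SPARK_SQL_SUFFIX.format(
        view_name="spark_ai_temp_view_1",
        comment="(TPC-H lineitem rows)",
        sample_vals="(l_orderkey, bigint, [1, 2, 3])\n(l_quantity, decimal(12,2), [17.00, 36.00, 8.00])",
        desc="total quantity per order key",
    )
    print(rendered)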

pyspark_ai/pyspark_ai.py

Lines changed: 70 additions & 1 deletion
@@ -2,6 +2,7 @@
 import io
 import os
 import re
+import time
 from typing import Callable, List, Optional
 from urllib.parse import urlparse
 
@@ -480,6 +481,7 @@ def _get_sql_query(
             )
         else:
             # Otherwise, generate the SQL query with a prompt with few-shot examples
+            print(f"-------------------------Start generating sql query with a prompt with few-shot examples-------------------------\n\n")
             return self.sql_chain.run(
                 view_name=temp_view_name,
                 sample_vals=sample_vals_str,
@@ -496,6 +498,7 @@ def _get_transform_sql_query(self, df: DataFrame, desc: str, cache: bool) -> str
         df.createOrReplaceTempView(temp_view_name)
         schema_lst = self._get_df_schema(df)
         schema_str = "\n".join(schema_lst)
+        print(f"-------------------------Current table schema from df is:-------------------------\n\n {schema_str}\n")
         sample_rows = self._get_sample_spark_rows(df)
         schema_row_lst = []
         for index in range(len(schema_lst)):
@@ -505,8 +508,9 @@ def _get_transform_sql_query(self, df: DataFrame, desc: str, cache: bool) -> str
             curr_schema_row = f"({schema_lst[index]}, {str(sample_vals)})"
             schema_row_lst.append(curr_schema_row)
         sample_vals_str = "\n".join([str(val) for val in schema_row_lst])
+        print(f"-------------------------Current sample vals are:-------------------------\n\n {sample_vals_str}\n")
         comment = self._get_table_comment(df)
-
+        print(f"-------------------------Current table comment is-------------------------\n\n {comment}\n")
         if cache:
             cache_key = ReActSparkSQLAgent.cache_key(desc, schema_str)
             cached_result = self._cache.lookup(key=cache_key)
@@ -523,6 +527,65 @@ def _get_transform_sql_query(self, df: DataFrame, desc: str, cache: bool) -> str
         else:
             return self._get_sql_query(temp_view_name, sample_vals_str, comment, desc)
 
+
+    def _get_table_schema(self, table: str) -> list:
+        df = self._spark.sql(f"select * from {table}")
+        schema_lst = [f"{name}, {dtype}" for name, dtype in df.dtypes]
+        return schema_lst
+
+    def _get_sample_spark_rows(self, df: DataFrame) -> list:
+
+        if self._sample_rows_in_table_info <= 0:
+            return []
+        try:
+            sample_rows = SparkUtils.get_dataframe_results(df.limit(3))
+            return sample_rows
+        except Exception:
+            # If fail to get sample rows, return empty list
+            return []
+
+    def _get_sample_spark_rows_tpch(self, table: str) -> list:
+
+        if self._sample_rows_in_table_info <= 0:
+            return []
+        df = self._spark.sql(f"select * from {table}")
+        try:
+            sample_rows = SparkUtils.get_dataframe_results(df.limit(3))
+            return sample_rows
+        except Exception:
+            # If fail to get sample rows, return empty list
+            return []
+
+    def _get_transform_sql_query_tpch(self, desc: str, table: str, cache: bool) -> str:
+        self.log(f"Retrieve table schema for {table} \n")
+        schema_lst = self._get_table_schema(table)
+        schema_str = "\n".join(schema_lst)
+        print(f"-------------------------Current table schema from df is:-------------------------\n\n {schema_str}\n")
+        sample_rows = self._get_sample_spark_rows_tpch(table)
+        schema_row_lst = []
+        for index in range(len(schema_lst)):
+            sample_vals = []
+            for sample_row in sample_rows:
+                sample_vals.append(sample_row[index])
+            curr_schema_row = f"({schema_lst[index]}, {str(sample_vals)})"
+            schema_row_lst.append(curr_schema_row)
+        sample_vals_str = "\n".join([str(val) for val in schema_row_lst])
+        print(f"-------------------------Current sample vals are:-------------------------\n\n {sample_vals_str}\n")
+        #comment = self._get_table_comment(df)
+        comment = ""
+        #print(f"-------------------------Current table comment is-------------------------\n\n {comment}\n")
+        return self._get_sql_query(table, sample_vals_str, comment, desc)
+
+    def transform_df_tpch(self, desc: str, table: str, cache: bool = False) -> DataFrame:
+        print(f"---------------------TPCH Table {table}------------------------------\n\n")
+        start_time = time.time()
+        sql_query = self._get_transform_sql_query_tpch(desc, table, cache)
+        end_time = time.time()
+        get_transform_sql_query_time = end_time - start_time
+        print(f"-------------------------End get_transform_sql_query-------------------------\n\n get_transform_sql_query_time: {get_transform_sql_query_time} seconds\n")
+        print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
+        return self._spark.sql(sql_query)
+
     def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFrame:
         """
         This method applies a transformation to a provided Spark DataFrame,
@@ -535,7 +598,13 @@ def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFram
         :return: Returns a new Spark DataFrame that is the result of applying the specified transformation
             on the input DataFrame.
         """
+        print(f"-------------------------Start get_transform_sql_query-------------------------\n\n")
+        start_time = time.time()
         sql_query = self._get_transform_sql_query(df, desc, cache)
+        end_time = time.time()
+        get_transform_sql_query_time = end_time - start_time
+        print(f"-------------------------End get_transform_sql_query-------------------------\n\n get_transform_sql_query_time: {get_transform_sql_query_time} seconds\n")
+        print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
         return self._spark.sql(sql_query)
 
     def explain_df(self, df: DataFrame, cache: bool = True) -> str:
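
Taken together, the `_tpch` path reads schema and sample rows straight from a named table instead of a DataFrame, then reuses `_get_sql_query`. A rough end-to-end sketch, with the table name and question invented for illustration:

    # Hedged sketch: assumes a TPC-H `orders` table is registered in the
    # current Spark catalog (not part of this commit).
    from pyspark.sql import SparkSession
    from pyspark_ai import SparkAI

    spark = SparkSession.builder.getOrCreate()
    spark_ai = SparkAI()

    # Reads the schema plus up to 3 sample rows from `orders`, builds the
    # few-shot prompt, prints timing, and executes the generated SQL.
    result = spark_ai.transform_df_tpch(
        desc="number of orders placed in 1995, grouped by order priority",
        table="orders",
    )
    result.show()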

pyspark_ai/python_executor.py

Lines changed: 3 additions & 3 deletions
@@ -38,9 +38,9 @@ def run(
     ) -> str:
         assert not args, "The chain expected no arguments"
         # assert llm is an instance of BaseChatModel
-        assert isinstance(
-            self.llm, BaseChatModel
-        ), "The llm is not an instance of BaseChatModel"
+        #assert isinstance(
+        #    self.llm, BaseChatModel
+        #), "The llm is not an instance of BaseChatModel"
         prompt_str = canonize_string(self.prompt.format_prompt(**kwargs).to_string())
         use_cache = tags != SKIP_CACHE_TAGS
         if self.cache is not None:

pyspark_ai/spark_sql_chain.py

Lines changed: 16 additions & 8 deletions
@@ -5,6 +5,7 @@
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage, HumanMessage
 from pyspark.sql import SparkSession
+from langchain_core.language_models.llms import BaseLLM
 
 from pyspark_ai.ai_utils import AIUtils
 from pyspark_ai.code_logger import CodeLogger
@@ -29,24 +30,28 @@ def run(
     ) -> str:
         assert not args, "The chain expected no arguments"
         # assert llm is an instance of BaseChatModel
-        assert isinstance(
-            self.llm, BaseChatModel
-        ), "The llm is not an instance of BaseChatModel"
+        #assert isinstance(
+        #    self.llm, BaseChatModel
+        #), "The llm is not an instance of BaseChatModel"
         prompt_str = self.prompt.format_prompt(**kwargs).to_string()
+        print(f"-------------------------Input prompt is:-------------------------\n\n {prompt_str}\n")
         messages = [HumanMessage(content=prompt_str)]
         return self._generate_code_with_retries(self.llm, messages, self.max_retries)
 
     def _generate_code_with_retries(
         self,
-        chat_model: BaseChatModel,
+        chat_model: BaseLLM,
         messages: List[BaseMessage],
         retries: int = 3,
     ) -> str:
         response = chat_model.predict_messages(messages)
-        if self.logger is not None:
-            self.logger.info(response.content)
+        print(f"-------------------------The model replies:-------------------------\n\n {response.content} \n")
+        #if self.logger is not None:
+        #    self.logger.info(response.content)
         code = AIUtils.extract_code_blocks(response.content)[0]
+        #code = response.content.split("\n")[1].split("Human:")[1].replace("`","")
         try:
+            print(f"-------------------------Spark retrieved sql:-------------------------\n\n {code}\n")
             self.spark.sql(code)
             return code
         except Exception as e:
@@ -61,7 +66,10 @@ def _generate_code_with_retries(
             if self.logger is not None:
                 self.logger.info("Retrying with " + str(retries) + " retries left")
 
-            messages.append(response)
+            # messages.append(response)
+            # Remove retry logic to prevent long response append and ensure accurate model results.
+
             # append the exception as a HumanMessage into messages
-            messages.append(HumanMessage(content=str(e)))
+            # messages.append(HumanMessage(content=str(e)))
+            # Remove retry logic to prevent long response append and ensure accurate model results.
             return self._generate_code_with_retries(chat_model, messages, retries - 1)
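
With the retry context trimmed, a single reply has to contain a usable query. A small sketch of what the chain now does with one reply; the reply text is invented and the fence-stripping behavior of `AIUtils.extract_code_blocks` is assumed from its use above:

    # Hedged sketch of the single-shot path in _generate_code_with_retries.
    from pyspark.sql import SparkSession
    from pyspark_ai.ai_utils import AIUtils

    spark = SparkSession.builder.getOrCreate()

    reply = "```sql\nSELECT o_orderpriority, COUNT(*) AS order_count FROM orders GROUP BY o_orderpriority\n```"
    code = AIUtils.extract_code_blocks(reply)[0]  # first fenced block -> SQL text
    print(code)
    spark.sql(code)  # validation: raises if the generated SQL does not run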
