@@ -2,6 +2,7 @@
 import io
 import os
 import re
+import time
 from typing import Callable, List, Optional
 from urllib.parse import urlparse
 
@@ -480,6 +481,7 @@ def _get_sql_query(
            )
        else:
            # Otherwise, generate the SQL query with a prompt with few-shot examples
+            print(f"-------------------------Start generating sql query with a prompt with few-shot examples-------------------------\n\n")
            return self.sql_chain.run(
                view_name=temp_view_name,
                sample_vals=sample_vals_str,
@@ -496,6 +498,7 @@ def _get_transform_sql_query(self, df: DataFrame, desc: str, cache: bool) -> str:
        df.createOrReplaceTempView(temp_view_name)
        schema_lst = self._get_df_schema(df)
        schema_str = "\n".join(schema_lst)
+        print(f"-------------------------Current table schema from df is:-------------------------\n\n {schema_str}\n")
        sample_rows = self._get_sample_spark_rows(df)
        schema_row_lst = []
        for index in range(len(schema_lst)):
@@ -505,8 +508,9 @@ def _get_transform_sql_query(self, df: DataFrame, desc: str, cache: bool) -> str:
            curr_schema_row = f"({schema_lst[index]}, {str(sample_vals)})"
            schema_row_lst.append(curr_schema_row)
        sample_vals_str = "\n".join([str(val) for val in schema_row_lst])
+        print(f"-------------------------Current sample vals are:-------------------------\n\n {sample_vals_str}\n")
        comment = self._get_table_comment(df)
-
+        print(f"-------------------------Current table comment is-------------------------\n\n {comment}\n")
        if cache:
            cache_key = ReActSparkSQLAgent.cache_key(desc, schema_str)
            cached_result = self._cache.lookup(key=cache_key)
@@ -523,6 +527,65 @@ def _get_transform_sql_query(self, df: DataFrame, desc: str, cache: bool) -> str:
        else:
            return self._get_sql_query(temp_view_name, sample_vals_str, comment, desc)
 
+
+    def _get_table_schema(self, table: str) -> list:
+        df = self._spark.sql(f"select * from {table}")
+        schema_lst = [f"{name}, {dtype}" for name, dtype in df.dtypes]
+        return schema_lst
+
+    def _get_sample_spark_rows(self, df: DataFrame) -> list:
+
+        if self._sample_rows_in_table_info <= 0:
+            return []
+        try:
+            sample_rows = SparkUtils.get_dataframe_results(df.limit(3))
+            return sample_rows
+        except Exception:
+            # If we fail to get sample rows, return an empty list
+            return []
+
+    def _get_sample_spark_rows_tpch(self, table: str) -> list:
+
+        if self._sample_rows_in_table_info <= 0:
+            return []
+        df = self._spark.sql(f"select * from {table}")
+        try:
+            sample_rows = SparkUtils.get_dataframe_results(df.limit(3))
+            return sample_rows
+        except Exception:
+            # If we fail to get sample rows, return an empty list
+            return []
+
+    def _get_transform_sql_query_tpch(self, desc: str, table: str, cache: bool) -> str:
+        self.log(f"Retrieve table schema for {table}\n")
+        schema_lst = self._get_table_schema(table)
+        schema_str = "\n".join(schema_lst)
+        print(f"-------------------------Current table schema from df is:-------------------------\n\n {schema_str}\n")
+        sample_rows = self._get_sample_spark_rows_tpch(table)
+        schema_row_lst = []
+        for index in range(len(schema_lst)):
+            sample_vals = []
+            for sample_row in sample_rows:
+                sample_vals.append(sample_row[index])
+            curr_schema_row = f"({schema_lst[index]}, {str(sample_vals)})"
+            schema_row_lst.append(curr_schema_row)
+        sample_vals_str = "\n".join([str(val) for val in schema_row_lst])
+        print(f"-------------------------Current sample vals are:-------------------------\n\n {sample_vals_str}\n")
+        # comment = self._get_table_comment(df)
+        comment = ""
+        # print(f"-------------------------Current table comment is-------------------------\n\n {comment}\n")
+        return self._get_sql_query(table, sample_vals_str, comment, desc)
+
+    def transform_df_tpch(self, desc: str, table: str, cache: bool = False) -> DataFrame:
+        print(f"---------------------TPCH Table {table}------------------------------\n\n")
+        start_time = time.time()
+        sql_query = self._get_transform_sql_query_tpch(desc, table, cache)
+        end_time = time.time()
+        get_transform_sql_query_time = end_time - start_time
+        print(f"-------------------------End get_transform_sql_query-------------------------\n\n get_transform_sql_query_time: {get_transform_sql_query_time} seconds\n")
+        print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
+        return self._spark.sql(sql_query)
+
    def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFrame:
        """
        This method applies a transformation to a provided Spark DataFrame,
@@ -535,7 +598,13 @@ def transform_df(self, df: DataFrame, desc: str, cache: bool = True) -> DataFrame:
        :return: Returns a new Spark DataFrame that is the result of applying the specified transformation
            on the input DataFrame.
        """
+        print(f"-------------------------Start get_transform_sql_query-------------------------\n\n")
+        start_time = time.time()
        sql_query = self._get_transform_sql_query(df, desc, cache)
+        end_time = time.time()
+        get_transform_sql_query_time = end_time - start_time
+        print(f"-------------------------End get_transform_sql_query-------------------------\n\n get_transform_sql_query_time: {get_transform_sql_query_time} seconds\n")
+        print(f"-------------------------Received query:-------------------------\n\n {sql_query}\n")
        return self._spark.sql(sql_query)
 
    def explain_df(self, df: DataFrame, cache: bool = True) -> str:
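
For reference, a minimal usage sketch of the new TPCH entry point this commit adds. This is an assumption-laden illustration, not part of the commit: the `agent` instance, its construction, and the `customer` table name are all hypothetical.

    # Hypothetical driver code; assumes a SparkSession with the TPC-H tables
    # registered as temp views/tables, and `agent` an already-constructed
    # ReActSparkSQLAgent bound to that session.
    desc = "total account balance per market segment"
    result_df = agent.transform_df_tpch(desc, table="customer")
    result_df.show()  # transform_df_tpch returns the DataFrame from spark.sql(sql_query)

Unlike transform_df, which takes a DataFrame and registers a temp view for it, transform_df_tpch takes a table name directly and skips the table-comment lookup (comment is hardcoded to "").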