Skip to content

Commit

Permalink
Fix string parsing errors #29
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Oct 11, 2023
1 parent b4466dd commit 39fe7e7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
5 changes: 4 additions & 1 deletion sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ def query_api(
res, alt_res = sql_g.generate_sql(
table_names, question, model_name=model_name, _dialect=db_dialect
)
res = res.replace("“", '"').replace("”", '"')
[res := res.replace(s, '"') for s in "‘`’'" if s in res]
logger.info(f"Input query: {question}")
logger.info(f"Generated response:\n\n{res}")

Expand All @@ -495,7 +497,8 @@ def query_api(

# Before executing, check if known vulnerabilities exist in the generated SQL code.
_val = _val.replace("“", '"').replace("”", '"')
[_val := _val.replace(s, "'") for s in "‘`" if s in _val]
[_val := _val.replace(s, '"') for s in "‘`’'" if s in _val]

r, m = check_vulnerability(_val)
if not r:
q_res, err = db_obj.execute_query_db(query=_val)
Expand Down
11 changes: 7 additions & 4 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,16 +323,19 @@ def generate_sql(
else:
# TODO Update needed for multiple tables
columns_w_type = (
self.context_builder.full_context_dict[table_name].split(":")[2].split("and")[0].strip()
self.context_builder.full_context_dict[table_name]
.split(":")[2]
.split(" and foreign keys")[0]
.strip()
)

data_samples_list = self.load_column_samples(table_names)

_context = {
"if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
"if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT",
"detailed summary": "include min, avg, max",
"summary": "include min, avg, max",
"detailed summary": "include min, avg, max for numeric columns",
"summary": "include min, avg, max for numeric columns",
}

m_path = f"{self.path}/models/sentence_transformers/"
Expand Down Expand Up @@ -408,7 +411,7 @@ def generate_sql(
]
data_samples_list = contextual_data_samples

relevant_columns = context_columns if len(context_columns) > 0 else clmn_names
relevant_columns = context_columns if len(context_columns) > 0 else [columns_w_type]
_column_info = ", ".join(relevant_columns)

logger.debug(f"Relevant sample column values: {data_samples_list}")
Expand Down

0 comments on commit 39fe7e7

Please sign in to comment.