Skip to content

Commit

Permalink
Additional syntax-based dialect validation #4
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Aug 9, 2023
1 parent d6a38b7 commit 97b2c78
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
15 changes: 14 additions & 1 deletion sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,20 @@ def generate_sql(
# COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
_temp = _res.replace("table_name", table_names[0]).split(";")[0]
res = "SELECT" + _temp + " COLLATE NOCASE;"
return res

    # Validate the generated SQL for parsing errors, along with dialect-specific validation
# Note: Doesn't do well with handling date-time conversions
# e.g.
# sqlite: SELECT DATETIME(MAX(timestamp), '-5 minute') FROM demo WHERE isin_id = 'VM88109EGG92'
# postgres: SELECT MAX(timestamp) - INTERVAL '5 minutes' FROM demo where isin_id='VM88109EGG92'
# Reference ticket: https://github.com/tobymao/sqlglot/issues/2011
result = res
try:
result = sqlglot.transpile(res, identify=True, write="sqlite")[0]
except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
    logger.info("We did the best we could, there might still be some errors:\n")
logger.info(f"Realized query so far:\n {res}")
return result

def task_formatter(self, input_task: str):
# Generated format
Expand Down
9 changes: 6 additions & 3 deletions sidekick/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,18 @@ def generate_text_embeddings(model_path: str, x, model_obj=None, batch_size: int
return res, model_obj


def filter_samples(input_q: str, probable_qs: list, model_path: str, model_obj=None, threshold: float = 0.45):
def filter_samples(
input_q: str, probable_qs: list, model_path: str, model_obj=None, threshold: float = 0.80, device="auto"
):
# Only consider the questions, note: this might change in future.
_inq = ("# query: " + input_q).strip().lower()
logger.debug(f"Input questions: {_inq}")
question_embeddings, model_obj = generate_text_embeddings(model_path, x=[_inq], model_obj=model_obj, device="cpu")
_device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device
question_embeddings, model_obj = generate_text_embeddings(model_path, x=[_inq], model_obj=model_obj, device=_device)

input_pqs = [_se.split("# answer")[0].strip().lower() for _se in probable_qs]
logger.debug(f"Probable context: {input_pqs}")
embeddings, model_obj = generate_text_embeddings(model_path, x=input_pqs, model_obj=model_obj, device="cpu")
embeddings, model_obj = generate_text_embeddings(model_path, x=input_pqs, model_obj=model_obj, device=_device)
res = {}
_scores = []
for idx, _se in enumerate(embeddings):
Expand Down

0 comments on commit 97b2c78

Please sign in to comment.