Commit

Cache column sample values for future use #4

pramitchoudhary committed Jul 28, 2023
1 parent c1e2cc4 commit 061822c
Showing 4 changed files with 77 additions and 16 deletions.
2 changes: 2 additions & 0 deletions sidekick/configs/data_template.py
@@ -11,3 +11,5 @@
"Column Type": "",
"Sample Values": []
}

data_samples_template = "Column {column_name} contains values similar to {comma_separated_sample_values}."
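
For a concrete sense of the template, here is a small illustration; the column name and sample values below are invented:

from sidekick.configs.data_template import data_samples_template

_ds = data_samples_template.format(
    column_name="merchant_state",
    comma_separated_sample_values=",".join(["CA", "NY", "TX"]),
)
# -> "Column merchant_state contains values similar to CA,NY,TX."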
15 changes: 15 additions & 0 deletions sidekick/db_config.py
@@ -7,6 +7,7 @@
import sqlalchemy
from pandasql import sqldf
from psycopg2.extras import Json
from sidekick.configs.data_template import data_samples_template
from sidekick.logger import logger
from sqlalchemy import create_engine
from sqlalchemy_utils import database_exists
@@ -84,6 +85,7 @@ def _extract_schema_info(self, schema_info_path=None):
with open(table_info_file, "r") as in_file:
schema_info_path = json.load(in_file)["schema_info_path"]
res = []
sample_values = []
try:
if Path(schema_info_path).exists():
with open(schema_info_path, "r") as in_file:
@@ -93,8 +95,21 @@
if "Column Name" in data and "Column Type" in data:
col_name = data["Column Name"]
col_type = data["Column Type"]
# If the column has sample values, cache them for future use.
if "Sample Values" in data:
_sample_values = data["Sample Values"]
_ds = data_samples_template.format(
column_name=col_name, comma_separated_sample_values=",".join(_sample_values)
)
sample_values.append(_ds)
_new_samples = f"{col_name} {col_type}"
res.append(_new_samples)
if len(sample_values):
# Cache the rendered sample-value sentences for future use
with open(
f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"
) as outfile:
json.dump(sample_values, outfile, indent=2, sort_keys=False)
except ValueError as ve:
logger.error(f"Error in reading table context file: {ve}")
pass
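The cache written above is a plain JSON list of rendered sentences. A minimal sketch of reading it back, with illustrative values standing in for self.base_path and self._table_name:

import json

base_path, table_name = ".", "sales"  # illustrative; the class uses self.base_path / self._table_name
cache_path = f"{base_path}/var/lib/tmp/data/{table_name}_column_values.json"
with open(cache_path, "r") as in_file:
    cached_sample_values = json.load(in_file)
# e.g. ["Column merchant_state contains values similar to CA,NY,TX.", ...]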
20 changes: 15 additions & 5 deletions sidekick/query.py
@@ -12,7 +12,7 @@
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT,
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT, QUERY_PROMPT, NSQL_QUERY_PROMPT,
TASK_PROMPT)
from sidekick.logger import logger
from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
@@ -50,6 +50,16 @@ def __init__(
self.openai_key = openai_key
self.content_queries = None

def load_table_info(self):
# Read the cached table context file
table_info_file = f"{self.path}/var/lib/tmp/data/table_context.json"

def setup(self):
# Load the table information
self.load_table_info()

def build_index(self, persist: bool = True):
# The re-assignment of the OPENAI API key below looks redundant, but without it an error is thrown.
os.environ["OPENAI_API_KEY"] = self.openai_key
@@ -271,16 +281,16 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "s
contextual_context.append(f"{_item}: {_val}")

print("Filtering Question/Query pairs")
-_samples = filter_samples(input_question, probable_qs=sample_pairs,
-model_path=local_model_path, threshold=0.90)
+_samples = filter_samples(input_question, probable_qs=context_queries,
+model_path='', threshold=0.90)

# If there are more than 5 QnA pairs, keep only 5 of them for focused context
if len(_samples) > 5:
_samples = _samples[0:5][::-1]
qna_samples = '\n'.join(_samples)

contextual_context_val = ', '.join(contextual_context)

column_names = [str(_c) for _c in self.sql_database.get_column_names(table_name[0])]
if len(_samples) > 2:
# Check for the columns in the QnA samples provided, if exists keep them
context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
@@ -290,7 +300,7 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "s
relevant_columns = context_columns if len(context_columns) > 0 else column_names
_data_info = ', '.join(relevant_columns)

-query = prompt_template.format(table_name=_table_name, data_info=_data_info, data_info_detailed=data_samples,
+query = NSQL_QUERY_PROMPT.format(table_name=table_name, data_info=_data_info, data_info_detailed=data_samples,
sample_queries=qna_samples, context=contextual_context_val,
question_txt=input_question)

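To make the trimming and column-filtering logic above concrete, a standalone illustration with invented data:

# Keep at most 5 QnA pairs, reversed (presumably so the closest match
# lands nearest the question in the assembled prompt).
_samples = ["q1", "q2", "q3", "q4", "q5", "q6"]
if len(_samples) > 5:
    _samples = _samples[0:5][::-1]  # -> ["q5", "q4", "q3", "q2", "q1"]
qna_samples = "\n".join(_samples)

# Keep only the columns actually mentioned in the retained samples,
# falling back to all columns when none match.
column_names = ["merchant_state", "txn_amount", "txn_date"]
context_columns = [_c for _c in column_names if _c.lower() in qna_samples.lower()]
relevant_columns = context_columns if len(context_columns) > 0 else column_names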
56 changes: 45 additions & 11 deletions sidekick/utils.py
@@ -4,10 +4,12 @@
from pathlib import Path
from typing import Optional

import torch
import numpy as np
import pandas as pd
from pandasql import sqldf
from sentence_transformers import SentenceTransformer
from InstructorEmbedding import INSTRUCTOR
from sidekick.logger import logger
from sklearn.metrics.pairwise import cosine_similarity

@@ -37,6 +39,38 @@ def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, devic
return all_res


def generate_text_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = 'cpu'):
# Reference:
# 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
# 2. Evaluation result: https://www.sbert.net/_static/html/models_en_sentence_embeddings.html
# 3. Model Card: https://huggingface.co/hkunlp/instructor-large
# 4. Reference: https://huggingface.co/spaces/mteb/leaderboard
# Maps sentences & paragraphs to a 768-dimensional dense vector space (instructor-large).
model_name_path = f"{model_path}/text_embedding/instructor-large"
current_torch_home = os.environ.get("TORCH_HOME", "")
if not Path(model_name_path).is_dir() or not any(Path(model_name_path).iterdir()):
# Local copy missing or empty: download and cache at the specified location
os.environ["TORCH_HOME"] = model_path
model_name_path = "hkunlp/instructor-large"
sentence_model = INSTRUCTOR(model_name_path, device=device)
if device != 'cuda':
# Issue https://github.com/pytorch/pytorch/issues/69364
# In initial experimentation, the quantized model generates slightly better results
_model = torch.quantization.quantize_dynamic(
sentence_model, {torch.nn.Linear}, dtype=torch.qint8)
else:
_model = sentence_model
_sentences = [['Represent the Financial question for retrieving duplicate examples: ', _item] for _item in x]

res = _model.encode(_sentences)
del sentence_model
del _model
os.environ["TORCH_HOME"] = current_torch_home
return res
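
# Usage sketch for the helper above (not part of this commit): with a local
# cache under ./models/text_embedding/instructor-large, a call might look like
#   res = generate_text_embeddings("./models", ["What was the total spend?"])
# One 768-dimensional vector per input sentence is returned (instructor-large).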


def filter_samples(input_q: str, probable_qs: list, model_path: str, threshold: float = 0.45):
# Only consider the questions; note: this might change in the future.
_inq = ("# query: " + input_q).strip().lower()
@@ -102,21 +136,21 @@ def read_sample_pairs(input_path: str, model_name: str = "nsql"):
df = df.reset_index(drop=True)

# NSQL format
-if model_name != 'nsql':
+if model_name != "nsql":
# OpenAI format
# Convert frame to below format
# [
# "# query": ""
# "# answer": ""
# ]
res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list()
else:
# Convert frame to below format
# [
# "Question": <question_text>
# "Answer":
# <response_text>
# ]
res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list()
return res

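Putting the two helpers together, a hypothetical flow (the file path is illustrative; read_sample_pairs expects records with 'query' and 'answer' fields, per the lambdas above):

pairs = read_sample_pairs("var/lib/tmp/data/samples.csv", model_name="nsql")
matches = filter_samples(
    "What was the total spend last month?",
    probable_qs=pairs,
    model_path="./models",
    threshold=0.90,
)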
