Decode and generate #4
pramitchoudhary committed Jul 28, 2023
1 parent 061822c commit 777d63c
Showing 3 changed files with 34 additions and 24 deletions.
2 changes: 1 addition & 1 deletion sidekick/configs/.env.toml
@@ -1,6 +1,6 @@
[OPENAI]
OPENAI_API_KEY = ""
MODEL_NAME = "nsql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
MODEL_NAME = "h2ogpt-sql" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003

[LOCAL_DB_CONFIG]
HOST_NAME = "localhost"
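The MODEL_NAME switch above is what routes generation to the local h2ogpt-sql/NSQL path in query.py below. For context, a minimal sketch of how such a config might be loaded (assuming the toml package; the repo's actual loader is not shown in this diff):

import toml

# Hypothetical loader; the path and key names are taken from the diff above.
env_config = toml.load("sidekick/configs/.env.toml")
model_name = env_config["OPENAI"]["MODEL_NAME"]  # "h2ogpt-sql" after this commit
print(model_name)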
45 changes: 28 additions & 17 deletions sidekick/query.py
@@ -6,7 +6,7 @@
import numpy as np
import openai
import sqlglot
-import toml
+import random
import torch
from langchain import OpenAI
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
@@ -50,15 +50,15 @@ def __init__(
self.openai_key = openai_key
self.content_queries = None

-def load_table_info(self):
-# Read table_info.jsonl
-table_info_file = f"{self.path}/var/lib/tmp/data/table_context.json"
+def setup(self):
+# Load the table information
+self.load_table_info()


def load_column_samples(self, tables: list):
# TODO: Maybe we add table name as a member variable
# Load column values if they exist
examples = {}
for _t in tables:
f_p = f"{self.path}/var/lib/tmp/data/{_t}_column_values.json"
with open(f_p, "r") as f:
examples[_t] = json.load(f)
return examples
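load_column_samples reads one {table}_column_values.json per table from var/lib/tmp/data. The JSON schema is not shown in this diff, and note the method opens each file unconditionally, so a missing file raises FileNotFoundError despite the "if they exist" comment. A hypothetical call, with the table name and schema purely illustrative:

# Hypothetical var/lib/tmp/data/sleep_column_values.json:
# {"duration": [7.5, 6.0, 8.1], "quality": ["good", "poor"]}
samples = sql_generator.load_column_samples(["sleep"])  # sql_generator: an instance of this class
print(samples["sleep"]["quality"])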

def build_index(self, persist: bool = True):
# Re-assigning the OPENAI API key below looks redundant, but without it an error is thrown.
@@ -206,12 +206,12 @@ def generate_tasks(self, table_names: list, input_question: str):
raise se


-def generate_sql(self, table_name: list, input_question: str, _dialect: str = "sqlite", model_name: str = "nsql"):
+def generate_sql(self, table_name: list, input_question: str, _dialect: str = "sqlite", model_name: str = "h2ogpt-sql"):
context_file = f"{self.path}/var/lib/tmp/data/context.json"
additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {}
context_queries = self.content_queries

if model_name != "nsql":
if model_name != "h2ogpt-sql":
_tasks = self.task_formatter(self._tasks)

# TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate.
@@ -263,6 +263,7 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "s
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")

data_samples = context_queries
+data_samples_list = self.load_column_samples(table_name)

_context = {
"if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL",
@@ -271,20 +272,19 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "s

filtered_context = filter_samples(input_question, probable_qs=list(_context.keys()),
model_path='', threshold=0.845)

print(f"Filter Context: {filtered_context}")
logger.debug(f"Filter Context: {filtered_context}")

contextual_context = []
for _item in filtered_context:
_val = _context.get(_item, None)
if _val:
contextual_context.append(f"{_item}: {_val}")

print("Filtering Question/Query pairs")
logger.info("Filtering Question/Query pairs")
_samples = filter_samples(input_question, probable_qs=context_queries,
model_path='', threshold=0.90)

-# If QnA pairs > 5, we keep only 5 of them for focused context
+# If QnA pairs > 5, keep only the top 5 for a focused context
if len(_samples) > 5:
_samples = _samples[0:5][::-1]
qna_samples = '\n'.join(_samples)
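filter_samples (from sidekick's utils) is not shown here; its arguments suggest a semantic-similarity filter that keeps candidates scoring above threshold against the input question. A minimal sketch of that idea, assuming sentence-transformers cosine similarity rather than the repo's actual implementation:

from sentence_transformers import SentenceTransformer, util

def filter_samples_sketch(question: str, probable_qs: list, threshold: float) -> list:
    # Embed the question once, then keep candidates whose cosine similarity clears the threshold.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    q_emb = model.encode(question, convert_to_tensor=True)
    c_emb = model.encode(probable_qs, convert_to_tensor=True)
    scores = util.cos_sim(q_emb, c_emb)[0]
    return [q for q, s in zip(probable_qs, scores) if s.item() >= threshold]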
@@ -304,7 +304,18 @@ def generate_sql(self, table_name: list, input_question: str, _dialect: str = "s
sample_queries=qna_samples, context=contextual_context_val,
question_txt=input_question)

-input_ids = tokenizer(query, return_tensors="pt").input_ids
+inputs = tokenizer(query, return_tensors="pt")

+# Generate SQL
+random_seed = random.randint(0, 50)
+torch.manual_seed(random_seed)

+# Greedy decoding for a quick response (temperature has no effect unless do_sample=True)
+output = model.generate(**inputs, max_new_tokens=300, temperature=0.5, output_scores=True,
+return_dict_in_generate=True)
+_res = tokenizer.decode(output[0][0], skip_special_tokens=True)
+# Precaution in case the table name comes out wrong during generation
+res = _res.replace('table_name', table_name[0])
return res
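The block above is the commit's core decode-and-generate path. A self-contained sketch of the same flow against the nsql-6B checkpoint, using NumbersStation's published prompt shape (sidekick's real template, assembled from table info, sample queries, and context, is built further up):

import random
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")

prompt = (
    "CREATE TABLE stadium (stadium_id number, name text, capacity number)\n\n"
    "-- Using valid SQLite, answer the following questions for the tables provided above.\n\n"
    "-- What is the maximum capacity of all stadiums?\n\n"
    "SELECT"
)

torch.manual_seed(random.randint(0, 50))  # mirrors the commit; a no-op under greedy decoding
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=300, return_dict_in_generate=True)
print(tokenizer.decode(output.sequences[0], skip_special_tokens=True))

Since do_sample=True is never passed, generation is deterministic greedy search; the random seed only starts to matter if sampling is enabled later.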

def task_formatter(self, input_task: str):
11 changes: 5 additions & 6 deletions sidekick/utils.py
@@ -39,7 +39,7 @@ def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, devic
return all_res


-def generate_text_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = 'cpu'):
+def generate_text_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = "cpu"):
# Reference:
# 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models
# 2. Evaluation result: https://www.sbert.net/_static/html/models_en_sentence_embeddings.html
@@ -55,14 +55,13 @@ def generate_text_embeddings(model_path: str, x, batch_size: int = 32, device: O
os.environ["TORCH_HOME"] = model_path
model_name_path = "hkunlp/instructor-large"
sentence_model = INSTRUCTOR(model_name_path, device=device)
-if device != 'cuda':
+if device != "cuda":
# Issue https://github.com/pytorch/pytorch/issues/69364
# In initial experimentation, the quantized model generates slightly better results
-_model = torch.quantization.quantize_dynamic(
-sentence_model, {torch.nn.Linear}, dtype=torch.qint8)
+_model = torch.quantization.quantize_dynamic(sentence_model, {torch.nn.Linear}, dtype=torch.qint8)
else:
_model = sentence_model
-_sentences = [['Represent the Financial question for retrieving duplicate examples: ', _item] for _item in x]
+_sentences = [["Represent the Financial question for retrieving duplicate examples: ", _item] for _item in x]

res = _model.encode(_sentences)
del sentence_model
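For reference on the quantization branch above: torch.quantization.quantize_dynamic swaps the listed module types for int8 dynamically-quantized equivalents, which mainly speeds up Linear-heavy CPU inference. A tiny self-contained sketch (illustrative, not sidekick code):

import torch
import torch.nn as nn

float_model = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 32))
int8_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    out = int8_model(torch.randn(4, 768))  # weights stored as int8, activations stay float
print(out.shape)  # torch.Size([4, 32])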
@@ -136,7 +135,7 @@ def read_sample_pairs(input_path: str, model_name: str = "nsql"):
df = df.reset_index(drop=True)

# NSQL format
if model_name != "nsql":
if model_name != "h2ogpt-sql":
# Open AI format
# Convert frame to below format
# [
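The OpenAI-format conversion in read_sample_pairs is truncated in this view, so the exact target format stays as elided above. For illustration only, one common shape for such question/query pairs (field names hypothetical):

import pandas as pd

df = pd.DataFrame({"question": ["How many rows are in t?"],
                   "query": ["SELECT COUNT(*) FROM t"]})
pairs = [{"prompt": q, "completion": a} for q, a in zip(df["question"], df["query"])]  # hypothetical keys
print(pairs)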
