Skip to content

Commit

Permalink
Load quantized version of the model for faster inference #4
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Aug 4, 2023
1 parent 83ab1c2 commit dd6ab34
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 19 deletions.
5 changes: 2 additions & 3 deletions sidekick/configs/.env.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ LOG-LEVEL = "INFO"
DB_TYPE = "sqlite"

[TABLE_INFO]
TABLE_INFO_PATH = "/examples/test/table_info.jsonl"
TABLE_SAMPLES_PATH = "/examples/test/masked_data_and_columns.csv"
TABLE_INFO_PATH = "/examples/demo/table_info.jsonl"
TABLE_SAMPLES_PATH = "/examples/demo/demo_data.csv"
TABLE_NAME = "demo"
DB_TYPE = "sqlite"
11 changes: 5 additions & 6 deletions sidekick/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def _extract_schema_info(self, schema_info_path=None):
if "Column Name" in data and "Column Type" in data:
col_name = data["Column Name"]
col_type = data["Column Type"]
if col_type.lower() == "text":
col_type = col_type + " COLLATE NOCASE"
# if column has sample values, save in cache for future use.
if "Sample Values" in data:
_sample_values = data["Sample Values"]
Expand All @@ -116,9 +118,7 @@ def _extract_schema_info(self, schema_info_path=None):
return res

def create_table(self, schema_info_path=None, schema_info=None):
engine = create_engine(
self._url, isolation_level="AUTOCOMMIT"
)
engine = create_engine(self._url, isolation_level="AUTOCOMMIT")
self._engine = engine
if self.schema_info is None:
if schema_info is not None:
Expand All @@ -139,9 +139,7 @@ def create_table(self, schema_info_path=None, schema_info=None):
return

def has_table(self):
engine = create_engine(
self._url
)
engine = create_engine(self._url)

return sqlalchemy.inspect(engine).has_table(self.table_name)

Expand Down Expand Up @@ -181,6 +179,7 @@ def execute_query_db(self, query=None, n_rows=100):

# Create a connection
connection = engine.connect()
logger.debug(f"Executing query:\n {query}")
result = connection.execute(query)

# Process the query results
Expand Down
1 change: 1 addition & 0 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def update_context():
@click.option("--table-info-path", "-t", help="Table info path", default=None)
@click.option("--sample-queries", "-s", help="Samples path", default=None)
def query(question: str, table_info_path: str, sample_queries: str):
"""Asks question and returns SQL."""
query_api(question=question, table_info_path=table_info_path, sample_queries=sample_queries, is_command=True)


Expand Down
28 changes: 19 additions & 9 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def __init__(
self._tasks = None
self.openai_key = openai_key
self.content_queries = None
self.model = None # Used for local LLMs
self.tokenizer = None # Used for local tokenizer

def load_column_samples(self, tables: list):
# TODO: Maybe we add table name as a member variable
Expand Down Expand Up @@ -267,8 +269,12 @@ def generate_sql(
logger.info(f"Realized query so far:\n {res}")
else:
# Load h2oGPT.NSQL model
tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B")
model = AutoModelForCausalLM.from_pretrained("NumbersStation/nsql-6B")
device = {"": 0} if torch.cuda.is_available() else "cpu"
if self.model is None:
self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
self.model = AutoModelForCausalLM.from_pretrained(
"NumbersStation/nsql-6B", device_map=device, load_in_8bit=True
)

# TODO Update needed for multiple tables
columns_w_type = (
Expand Down Expand Up @@ -321,8 +327,8 @@ def generate_sql(
logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}")
# If QnA pairs > 5, we keep top 5 for focused context
_samples = filtered_context
if len(filtered_context) > 5:
_samples = filtered_context[0:5][::-1]
if len(filtered_context) > 3:
_samples = filtered_context[0:3][::-1]
qna_samples = "\n".join(_samples)

contextual_context_val = ", ".join(contextual_context)
Expand Down Expand Up @@ -357,24 +363,28 @@ def generate_sql(

logger.debug(f"Query Text:\n {query}")
inputs = tokenizer([query], return_tensors="pt")
input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
# Generate SQL
random_seed = random.randint(0, 50)
torch.manual_seed(random_seed)

# Greedy search for quick response
output = model.generate(
**inputs,
self.model.eval()
device_type = "cuda" if torch.cuda.is_available() else "cpu"
output = self.model.generate(
**inputs.to(device_type),
max_new_tokens=300,
temperature=0.5,
output_scores=True,
return_dict_in_generate=True,
)

generated_tokens = output.sequences[:, input_length:]
_res = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
_res = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
# Below is a pre-caution in-case of an error in table name during generation
res = "SELECT" + _res.replace("table_name", table_names[0])
# COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
_temp = _res.replace("table_name", table_names[0]).split(";")[0]
res = "SELECT" + _temp + " COLLATE NOCASE;"
return res

def task_formatter(self, input_task: str):
Expand Down
2 changes: 1 addition & 1 deletion sidekick/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def filter_samples(input_q: str, probable_qs: list, model_path: str, model_obj=N
_scores.append(similarities_score[0][0])

sorted_res = sorted(res.items(), key=lambda x: x[1], reverse=True)
logger.info(f"Sorted context: {sorted_res}")
logger.debug(f"Sorted context: {sorted_res}")
return list(dict(sorted_res).keys()), model_obj


Expand Down

0 comments on commit dd6ab34

Please sign in to comment.