Skip to content

Commit

Permalink
Fix initialization #4
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Aug 7, 2023
1 parent dd6ab34 commit a7dfae3
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@
import sqlglot
import torch
from langchain import OpenAI
from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
LLMPredictor, ServiceContext, SQLDatabase)
from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
from llama_index.indices.struct_store import SQLContextContainerBuilder
from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
NSQL_QUERY_PROMPT, QUERY_PROMPT,
TASK_PROMPT)
from sidekick.configs.prompt_template import DEBUGGING_PROMPT, NSQL_QUERY_PROMPT, QUERY_PROMPT, TASK_PROMPT
from sidekick.logger import logger
from sidekick.utils import filter_samples, read_sample_pairs, remove_duplicates
from sqlalchemy import create_engine
Expand Down Expand Up @@ -271,7 +268,7 @@ def generate_sql(
# Load h2oGPT.NSQL model
device = {"": 0} if torch.cuda.is_available() else "cpu"
if self.model is None:
self.tokenizer = tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
self.tokenizer = AutoTokenizer.from_pretrained("NumbersStation/nsql-6B", device_map=device)
self.model = AutoModelForCausalLM.from_pretrained(
"NumbersStation/nsql-6B", device_map=device, load_in_8bit=True
)
Expand Down Expand Up @@ -362,7 +359,7 @@ def generate_sql(
)

logger.debug(f"Query Text:\n {query}")
inputs = tokenizer([query], return_tensors="pt")
inputs = self.tokenizer([query], return_tensors="pt")
input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
# Generate SQL
random_seed = random.randint(0, 50)
Expand Down

0 comments on commit a7dfae3

Please sign in to comment.