Skip to content

Commit

Permalink
Update system config to support 'defog/sqlcoder-7b-2' #4
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Feb 7, 2024
1 parent e4b6c47 commit 008ae9a
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 48 deletions.
4 changes: 2 additions & 2 deletions sidekick/prompter.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,8 @@ def ask(
_err = _tmp[0].split("Error occurred:")[1] if len(_tmp) > 0 else None
env_url = os.environ["H2OGPTE_URL"]
env_key = os.environ["H2OGPTE_API_TOKEN"]
corr_sql = sql_g.self_correction(input_query=_val, error_msg=_err, remote_url=env_url, client_key=env_key)
q_res, err = DBConfig.execute_query(query=corr_sql)
corrected_sql = sql_g.self_correction(input_query=_val, error_msg=_err, remote_url=env_url, client_key=env_key)
q_res, err = DBConfig.execute_query(query=corrected_sql)
if not 'Error occurred'.lower() in str(err).lower():
err = None
attempt += 1
Expand Down
64 changes: 30 additions & 34 deletions sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random
import sys
from pathlib import Path

import requests
import numpy as np
import openai
import sqlglot
Expand Down Expand Up @@ -247,30 +247,30 @@ def self_correction(self, error_msg, input_query, remote_url, client_key):
_res = input_query
self_correction_model = os.getenv("SELF_CORRECTION_MODEL", "h2oai/h2ogpt-4096-llama2-70b-chat")
if "h2ogpt-" in self_correction_model:
if remote_url and client_key:
try:
from h2ogpte import H2OGPTE
client = H2OGPTE(address=remote_url, api_key=client_key)
text_completion = client.answer_question(
system_prompt=system_prompt,
text_context_list=[],
question=user_prompt,
llm=self_correction_model)
except Exception as e:
logger.info(f"H2OGPTE client is not configured, reach out if API key is needed, {e}. Attempting to use H2OGPT client")
# Make attempt to use h2ogpt client with OSS access
_api_key = client_key if client_key else "***"
client_args = dict(base_url=remote_url, api_key=_api_key, timeout=20.0)
query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
h2ogpt_base_client = OpenAI(**client_args)
h2ogpt_base_client.with_options(max_retries=3).chat.completions.create(
model=self_correction_model,
messages=query_msg,
max_tokens=512,
temperature=0.5,
stop="```",
seed=42)
text_completion = completion.choices[0].message
if remote_url and client_key and remote_url != "" and client_key != "":
from h2ogpte import H2OGPTE
client = H2OGPTE(address=remote_url, api_key=client_key)
text_completion = client.answer_question(
system_prompt=system_prompt,
text_context_list=[],
question=user_prompt,
llm=self_correction_model)
else:
logger.info(f"H2OGPTE client is not configured, attempting to use OSS H2OGPT client")
h2o_client_url = os.getenv("H2OGPT_BASE_URL", None)
h2o_client_key = os.getenv("H2OGPT_BASE_API_TOKEN", None)
# Make attempt to use h2ogpt client with OSS access
client_args = dict(base_url=h2o_client_url, api_key=h2o_client_key, timeout=20.0)
query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
h2ogpt_base_client = OpenAI(**client_args)
completion = h2ogpt_base_client.with_options(max_retries=3).chat.completions.create(
model=self_correction_model,
messages=query_msg,
max_tokens=512,
temperature=0.5,
stop="```",
seed=42)
text_completion = completion.choices[0].message
_response = text_completion.content
elif 'gpt-3.5' in self_correction_model.lower() or 'gpt-4' in self_correction_model.lower():
# Check if the API key is set, else inform user
Expand Down Expand Up @@ -390,7 +390,7 @@ def generate_sql(
self,
table_names: list,
input_question: str,
model_name: str = "h2ogpt-sql-nsql-llama-2-7B",
model_name: str = "h2ogpt-sql-sqlcoder-7b-2",
):
# TODO: Update needed to support multiple tables
table_name = str(table_names[0].replace(" ", "_")).lower()
Expand Down Expand Up @@ -464,7 +464,7 @@ def generate_sql(
remote_h2ogpt_base_url = os.environ.get("H2OGPT_URL", None)
if model_name == 'h2ogpt-sql-sqlcoder-34b-alpha':
remote_h2ogpt_base_url = f"{remote_h2ogpt_base_url}:5000/v1"
elif model_name == 'h2ogpt-sql-sqlcoder2':
elif model_name == 'h2ogpt-sql-sqlcoder-7b-2':
remote_h2ogpt_base_url = f"{remote_h2ogpt_base_url}:5001/v1"
elif model_name == 'h2ogpt-sql-nsql-llama-2-7B':
remote_h2ogpt_base_url = f"{remote_h2ogpt_base_url}:5002/v1"
Expand Down Expand Up @@ -637,7 +637,7 @@ def generate_sql(
# Greedy decoding, for fast response
# Reset temperature to 0.5
current_temperature = 0.5
if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
if model_name == "h2ogpt-sql-sqlcoder-7b-2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
query_txt = [{"role": "user", "content": query},]
logger.debug(f"Generation with default temperature : {current_temperature}")
Expand Down Expand Up @@ -667,7 +667,7 @@ def generate_sql(
# throttle temperature for different result
logger.info("Regeneration requested on previous query ...")
logger.debug(f"Selected temperature for fast regeneration : {random_temperature}")
if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
if model_name == "h2ogpt-sql-sqlcoder-7b-2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
query_txt = [{"role": "user", "content": query},]
completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create(
Expand All @@ -692,7 +692,7 @@ def generate_sql(
logger.debug(f"Temperature saved: {self.current_temps[model_name]}")
else:
logger.info("Regeneration with options requested on previous query ...")
if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
if model_name == "h2ogpt-sql-sqlcoder-7b-2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B":
logger.info("Generating diverse options, not enabled for remote models")
m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha")
query_txt = [{"role": "user", "content": query},]
Expand Down Expand Up @@ -804,10 +804,6 @@ def generate_sql(

h2o_client_url = os.getenv("H2OGPT_API_TOKEN", None)
h2o_client_key = os.getenv("H2OGPTE_API_TOKEN", None)
if not h2o_client_url or not h2o_client_key:
logger.info(f"H2OGPTE client is not configured, attempting to use OSS H2OGPT client")
h2o_client_url = os.getenv("H2OGPT_BASE_URL", None)
h2o_client_key = os.getenv("H2OGPT_BASE_API_TOKEN", None)
try:
result = self.self_correction(input_query=res, error_msg=str(ex_traceback), remote_url=h2o_client_url, client_key=h2o_client_key)
except Exception as se:
Expand Down
7 changes: 5 additions & 2 deletions sidekick/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)


REMOTE_LLMS = ["h2ogpt-sql-sqlcoder-34b-alpha", "h2ogpt-sql-sqlcoder2", "h2ogpt-sql-nsql-llama-2-7B",
"gpt-3.5-turbo", "gpt-4-8k", "gpt-4-1106-preview-128k"]
"h2ogpt-sql-sqlcoder-7b-2", "gpt-3.5-turbo", "gpt-4-8k", "gpt-4-1106-preview-128k"]

# clone of models from https://huggingface.co/models
# suffix `h2ogpt-sql-` is added to avoid conflict with the original models (we haven't done any changes to the original models yet)
Expand All @@ -33,6 +34,7 @@
"h2ogpt-sql-sqlcoder-34b-alpha-4bit": "defog/sqlcoder-34b-alpha",
"h2ogpt-sql-nsql-llama-2-7B-4bit": "NumbersStation/nsql-llama-2-7B",
"h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
"h2ogpt-sql-sqlcoder-7b-2": "defog/sqlcoder-7b-2",
"h2ogpt-sql-sqlcoder-34b-alpha": "defog/sqlcoder-34b-alpha",
"h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
"gpt-3.5-turbo": "gpt-3.5-turbo-1106",
Expand All @@ -47,6 +49,7 @@
# "h2ogpt-sql-nsql-llama-2-7B-4bit": "NumbersStation/nsql-llama-2-7B",
# "h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
"h2ogpt-sql-sqlcoder-34b-alpha": "defog/sqlcoder-34b-alpha",
"h2ogpt-sql-sqlcoder-7b-2": "defog/sqlcoder-7b-2",
"h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B"
}

Expand Down Expand Up @@ -549,7 +552,7 @@ def check_vulnerability(input_query: str):
# Currently, only support only for models as an endpoints
logger.debug(f"Requesting additional scan using configured models")
h2ogpt_client_url = h2ogpt_client_key = None
h2ogpte_client_url = os.getenv("H2OGPT_API_TOKEN", None)
h2ogpte_client_url = os.getenv("H2OGPTE_URL", None)
h2ogpte_client_key = os.getenv("H2OGPTE_API_TOKEN", None)
if not h2ogpte_client_url or not h2ogpte_client_key:
logger.info(f"H2OGPTE client is not configured, attempting to use OSS H2OGPT client")
Expand Down
4 changes: 1 addition & 3 deletions start.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@ def setup_dir(base_path: str):
# Model 1:
print(f"Download model 1...")
snapshot_download(repo_id="NumbersStation/nsql-llama-2-7B", cache_dir=f"{base_path}/models/")

# Model 2:
print(f"Download model 2...")
snapshot_download(repo_id="defog/sqlcoder2", cache_dir=f"{base_path}/models/")
# Model 3:
print(f"Download model 3...")
snapshot_download(repo_id="defog/sqlcoder-34b-alpha", cache_dir=f"{base_path}/models/")

print(f"Download embedding model...")
Expand Down
43 changes: 38 additions & 5 deletions tests/test_self_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,15 @@ def test_input2():
assert result != input_q
assert error is None

@pytest.mark.parametrize("input_q, debugger", [("""SELECT CONCAT("age", " ", "heart_rate") AS "age_heart_rate" FROM "sleep_health_and_lifestyle" ORDER BY "age_heart_rate" DESC LIMIT 100
""", "h2oai/h2ogpt-4096-llama2-70b-chat"),
@pytest.mark.parametrize("input_q, debugger, base_model", [("""SELECT CONCAT("age", " ", "heart_rate") AS "age_heart_rate" FROM "sleep_health_and_lifestyle" ORDER BY "age_heart_rate" DESC LIMIT 100
""", "h2oai/h2ogpt-4096-llama2-70b-chat", "h2ogpt-sql-sqlcoder-34b-alpha"),
("""SELECT CONCAT("age", " ", "heart_rate") AS "age_heart_rate" FROM "sleep_health_and_lifestyle" ORDER BY "age_heart_rate" DESC LIMIT 100
""", "gpt-3.5-turbo")])
def test_input3(input_q, debugger):
""", "h2oai/h2ogpt-4096-llama2-70b-chat", "h2ogpt-sql-sqlcoder-34b-alpha"),
("""SELECT CONCAT("age", " ", "heart_rate") AS "age_heart_rate" FROM "sleep_health_and_lifestyle" ORDER BY "age_heart_rate" DESC LIMIT 100
""", "h2oai/h2ogpt-4096-llama2-70b-chat", "h2ogpt-sql-sqlcoder-7b-2"),
("""SELECT CONCAT("age", " ", "heart_rate") AS "age_heart_rate" FROM "sleep_health_and_lifestyle" ORDER BY "age_heart_rate" DESC LIMIT 100
""", "gpt-3.5-turbo", "h2ogpt-sql-sqlcoder-7b-2")])
def test_input3(input_q, debugger, base_model):
# There is no CONCAT function in SQLite
os.environ["SELF_CORRECTION_MODEL"] = debugger
question = f"Execute SQL:\n{input_q}"
Expand All @@ -155,7 +159,7 @@ def test_input3(input_q, debugger):
sample_queries_path=None,
table_name=table_name,
is_command=False,
model_name="h2ogpt-sql-sqlcoder-34b-alpha",
model_name=base_model,
is_regenerate=False,
is_regen_with_options=False,
execute_query=True,
Expand All @@ -166,3 +170,32 @@ def test_input3(input_q, debugger):
)
assert error == None
assert res != None

# Fixing correlation function needs further investigation
@pytest.mark.parametrize("input_q, debugger, base_model", [
    ("""Correlation between sleep duration and quality of sleep""", "h2oai/h2ogpt-4096-llama2-70b-chat", "h2ogpt-sql-sqlcoder-34b-alpha"),
    ("""Correlation between sleep duration and quality of sleep""", "h2oai/h2ogpt-4096-llama2-70b-chat", "h2ogpt-sql-sqlcoder-7b-2"),
    ("""Correlation between sleep duration and quality of sleep""", "gpt-3.5-turbo", "h2ogpt-sql-sqlcoder-7b-2"),
    ("""Correlation between sleep duration and quality of sleep" AS "s" LIMIT 100?""", "gpt-4-8k", "h2ogpt-sql-sqlcoder-7b-2")])
def test_input4(input_q, debugger, base_model):
    """End-to-end check that `ask` with self-correction enabled recovers from
    correlation-style questions (including one deliberately malformed input,
    the last parametrize case) without surfacing an error.

    Parametrized over:
      input_q    -- natural-language question fed to the pipeline
      debugger   -- model used for the self-correction pass (via env var)
      base_model -- base SQL-generation model passed to `ask`
    """
    # NOTE(review): the self-correction model is selected through an env var,
    # so this mutates process-global state for subsequent tests.
    os.environ["SELF_CORRECTION_MODEL"] = debugger
    question = f"Execute SQL:\n{input_q}"
    print(f"Model Name/Debugger: {base_model}/{debugger}")
    res, _, error = ask(
        question=question,
        table_info_path=table_info_path,
        sample_queries_path=None,
        table_name=table_name,
        is_command=False,
        model_name=base_model,
        is_regenerate=False,
        is_regen_with_options=False,
        execute_query=True,
        local_base_path=base_path,
        debug_mode=False,
        guardrails=False,
        self_correction=True
    )
    # Success criterion: self-correction produced a runnable result.
    assert error == None
    assert res != None
5 changes: 3 additions & 2 deletions ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

# env variables
env_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml")
DEFAULT_MODEL = "h2ogpt-sql-sqlcoder-34b-alpha"

# Pre-initialize the models for faster response
def initialize_models():
Expand All @@ -41,7 +42,7 @@ def initialize_models():
_ = SQLGenerator(
db_url=None,
openai_key=None,
model_name="h2ogpt-sql-sqlcoder-34b-alpha",
model_name=DEFAULT_MODEL,
job_path=base_path,
data_input_path="",
sample_queries_path="",
Expand Down Expand Up @@ -135,7 +136,7 @@ async def chat(q: Q):

MODEL_CHOICE_MAP = q.client.model_choices
model_choices = [ui.choice(_key, _key) for _key in MODEL_CHOICE_MAP.keys()]
q.client.model_choice_dropdown = q.args.model_choice_dropdown = "h2ogpt-sql-sqlcoder-34b-alpha"
q.client.model_choice_dropdown = q.args.model_choice_dropdown = DEFAULT_MODEL

task_choices = [ui.choice("q_a", "Ask Questions"), ui.choice("sqld", "Debugging")]
q.client.task_choice_dropdown = q.args.task_dropdown = "q_a"
Expand Down

0 comments on commit 008ae9a

Please sign in to comment.