
Load default model during app init #4
pramitchoudhary committed Oct 17, 2023
1 parent b0d0261 commit 8b50af6
Showing 3 changed files with 32 additions and 12 deletions.
sidekick/query.py: 4 changes (2 additions, 2 deletions)
```diff
@@ -79,8 +79,8 @@ def __init__(
         is_regenerate_with_options: bool = False,
     ):
         self.db_url = db_url
-        self.engine = create_engine(db_url)
-        self.sql_database = SQLDatabase(self.engine)
+        self.engine = create_engine(db_url) if db_url else None
+        self.sql_database = SQLDatabase(self.engine) if self.engine else None
         self.context_builder = None
         self.data_input_path = _check_file_info(data_input_path)
         self.sample_queries_path = sample_queries_path
```
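
Guarding engine construction is what lets the new `on_startup` hook in `ui/app.py` below build a `SQLGenerator` with no database URL: SQLAlchemy's `create_engine` raises on a missing URL, so the engine is simply deferred until a real connection string arrives. A minimal standalone sketch of the pattern (the class and names here are illustrative, not the project's API):

```python
from typing import Optional

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


class LazyEngine:
    """Defer SQLAlchemy engine creation until a db_url exists."""

    def __init__(self, db_url: Optional[str] = None):
        self.db_url = db_url
        # create_engine(None) raises ArgumentError, so guard it: with no
        # URL we stay constructible and attach the engine later.
        self.engine: Optional[Engine] = create_engine(db_url) if db_url else None


warm = LazyEngine()                      # fine: no database yet
live = LazyEngine("sqlite:///:memory:")  # real engine once a URL exists
```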
sidekick/utils.py: 19 changes (9 additions, 10 deletions)
```diff
@@ -15,8 +15,7 @@
 from sentence_transformers import SentenceTransformer
 from sidekick.logger import logger
 from sklearn.metrics.pairwise import cosine_similarity
-from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig)
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig


 def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
```
```diff
@@ -269,9 +268,9 @@ def get_table_keys(file_path: str, table_key: str):
     return res, data


-def is_resource_low():
-    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
-    total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
+def is_resource_low(device_index: int = 0):
+    free_in_GB = int(torch.cuda.mem_get_info(device_index)[0] / 1024**3)
+    total_memory = int(torch.cuda.get_device_properties(device_index).total_memory / 1024**3)
     logger.info(f"Total Memory: {total_memory}GB")
     logger.info(f"Free GPU memory: {free_in_GB}GB")
     off_load = True
```
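
`torch.cuda.mem_get_info(device)` returns a `(free_bytes, total_bytes)` tuple for that device, so threading `device_index` through makes the check explicit about which GPU it inspects on multi-GPU hosts. A small sketch of the same probe, assuming CUDA is available:

```python
import torch


def gpu_memory_gb(device_index: int = 0) -> tuple[int, int]:
    """Return (free_GB, total_GB) for a single CUDA device."""
    # mem_get_info yields (free_bytes, total_bytes) for the device.
    free_bytes, _ = torch.cuda.mem_get_info(device_index)
    total_bytes = torch.cuda.get_device_properties(device_index).total_memory
    return int(free_bytes / 1024**3), int(total_bytes / 1024**3)


if torch.cuda.is_available():
    free_gb, total_gb = gpu_memory_gb(0)
    print(f"GPU 0: {free_gb}GB free of {total_gb}GB")
```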
```diff
@@ -296,20 +295,21 @@ def load_causal_lm_model(
     }
     model_name = model_choices_map[model_type]
     logger.info(f"Loading model: {model_name}")
+    device_index = 0
     # Load h2oGPT.SQL model
-    device = {"": 0} if torch.cuda.is_available() else "cpu" if device == "auto" else device
+    device = {"": device_index} if torch.cuda.is_available() else "cpu" if device == "auto" else device
     total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
     free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
     logger.info(f"Free GPU memory: {free_in_GB}GB")
     n_gpus = torch.cuda.device_count()
     logger.info(f"Total GPUs: {n_gpus}")
     _load_in_8bit = load_in_8bit

     # 22GB (Least requirement on GPU) is a magic number for the current model size.
     if off_load and re_generate and total_memory < 22:
         # To prevent the system from crashing in-case memory runs low.
         # TODO: Performance when offloading to CPU.
-        max_memory = f"{4}GB"
-        max_memory = {i: max_memory for i in range(n_gpus)}
+        max_memory = {device_index: f"{4}GB"}
         logger.info(f"Max Memory: {max_memory}, offloading to CPU")
         with init_empty_weights():
             config = AutoConfig.from_pretrained(model_name, cache_dir=cache_path, offload_folder=cache_path)
```
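
For context on this offload branch: accelerate's `init_empty_weights` builds the model skeleton on the meta device, so parameters occupy no memory until real weights are dispatched under the `max_memory` budget. A sketch of that idiom, assuming `accelerate` and `transformers` are installed (the checkpoint name is a placeholder):

```python
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")  # placeholder checkpoint
with init_empty_weights():
    # Shapes only, no storage: every parameter lives on the "meta"
    # device until weights are loaded and dispatched later.
    skeleton = AutoModelForCausalLM.from_config(config)

print(next(skeleton.parameters()).device)  # -> meta
```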
```diff
@@ -322,8 +322,7 @@
             _load_in_8bit = True
             load_in_4bit = False
         else:
-            max_memory = f"{int(free_in_GB)-2}GB"
-            max_memory = {i: max_memory for i in range(n_gpus)}
+            max_memory = {device_index: f"{int(free_in_GB)-2}GB"}
             _offload_state_dict = False
             _llm_int8_enable_fp32_cpu_offload = False
```
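
Both `max_memory` edits replace a dict spread over every visible GPU with a single-entry dict keyed by `device_index`, pinning placement to one card instead of sharding across all of them. In transformers/accelerate, `max_memory` maps device ids (GPU indices, plus an optional `"cpu"` key) to capacity strings. A hedged sketch of how such a cap is typically passed (the checkpoint name and headroom figures are placeholders, not the repo's exact values):

```python
import torch
from transformers import AutoModelForCausalLM

free_gb = int(torch.cuda.mem_get_info(0)[0] / 1024**3)

# Keep the model on GPU 0 with ~2GB headroom; whatever does not fit
# under the cap spills to CPU when device_map="auto" plans placement.
max_memory = {0: f"{free_gb - 2}GB", "cpu": "24GB"}

model = AutoModelForCausalLM.from_pretrained(
    "defog/sqlcoder2",  # placeholder checkpoint
    device_map="auto",
    max_memory=max_memory,
)
```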
ui/app.py: 21 changes (21 additions, 0 deletions)
```diff
@@ -11,6 +11,7 @@
 from h2o_wave import Q, app, data, handle_on, main, on, ui
 from h2o_wave.core import expando_to_dict
 from sidekick.prompter import db_setup_api, query_api
+from sidekick.query import SQLGenerator
 from sidekick.utils import get_table_keys, save_query, setup_dir, update_tables

 # Load the config file and initialize required paths
```
```diff
@@ -197,6 +198,8 @@ async def chatbot(q: Q):
     question = f"{q.args.chatbot}"
     logging.info(f"Question: {question}")

+    if q.args.table_dropdown or q.args.model_choice_dropdown:
+        return
     # For regeneration, currently there are 2 modes
     # 1. Quick fast approach by throttling the temperature
     # 2. "Try harder mode (THM)" Slow approach by using the diverse beam search
```
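
In Wave, every UI event reaches the handlers with the triggering component's name set on `q.args`, so without this guard a table or model dropdown change would re-submit the stale chatbot text as a new question. A minimal sketch of the early-return pattern (component names follow the diff):

```python
from h2o_wave import Q


async def chatbot(q: Q) -> None:
    # The event came from a dropdown, not the chat box: nothing to ask.
    if q.args.table_dropdown or q.args.model_choice_dropdown:
        return
    question = f"{q.args.chatbot}"
    ...  # generate SQL for the question
```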
```diff
@@ -662,6 +665,23 @@ async def on_event(q: Q):
     return event_handled


+def on_startup():
+    logging.info("SQL-Assistant started!")
+    logging.info(f"Initializing default model")
+
+    _ = SQLGenerator(
+        None,
+        None,
+        model_name="h2ogpt-sql-sqlcoder2",
+        job_path=base_path,
+        data_input_path="",
+        sample_queries_path="",
+        is_regenerate_with_options="",
+        is_regenerate="",
+    )
+    return
+
+
 @app("/", on_shutdown=on_shutdown)
 async def serve(q: Q):
     # Run only once per client connection.
@@ -670,6 +690,7 @@
         setup_dir(base_path)
         await init(q)
         q.client.initialized = True
+        on_startup()
         logging.info("App initialized.")

     # Handle routing.
```
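
Constructing `SQLGenerator` in `on_startup` pays the model-load cost once at boot, so the first chat request hits warm h2ogpt-sql-sqlcoder2 weights instead of triggering a cold load (note the diff passes empty strings where `query.py` declares booleans; they work only because they are falsy). A sketch of the same warm-up idea in isolation, with hypothetical helper names and real booleans:

```python
import logging
from functools import lru_cache


@lru_cache(maxsize=1)
def load_default_model(model_name: str = "h2ogpt-sql-sqlcoder2"):
    """Load the default model once per process and cache it."""
    logging.info("Initializing default model: %s", model_name)
    ...  # expensive weight loading goes here
    return model_name


def on_startup() -> None:
    # Invoked once during app init so later requests reuse the
    # already-cached model instead of paying the load cost.
    load_default_model()
```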
