Dynamically offload state between GPU and CPU as needed #4
pramitchoudhary committed Aug 22, 2023
1 parent 32bc877 commit 80cb1f4
Showing 3 changed files with 85 additions and 25 deletions.
7 changes: 6 additions & 1 deletion sidekick/prompter.py
@@ -393,7 +393,12 @@ def query_api(
     table_info_path = _get_table_info(path)
 
     sql_g = SQLGenerator(
-        db_url, api_key, job_path=base_path, data_input_path=table_info_path, sample_queries_path=sample_queries_path
+        db_url,
+        api_key,
+        job_path=base_path,
+        data_input_path=table_info_path,
+        sample_queries_path=sample_queries_path,
+        regenerate=is_regenerate,
     )
     if "h2ogpt-sql" not in model_name:
         sql_g._tasks = sql_g.generate_tasks(table_names, question)
43 changes: 27 additions & 16 deletions sidekick/query.py
@@ -10,16 +10,19 @@
 import torch
 import torch.nn.functional as F
 from langchain import OpenAI
-from llama_index import (GPTSimpleVectorIndex, GPTSQLStructStoreIndex,
-                         LLMPredictor, ServiceContext, SQLDatabase)
+from llama_index import GPTSimpleVectorIndex, GPTSQLStructStoreIndex, LLMPredictor, ServiceContext, SQLDatabase
 from llama_index.indices.struct_store import SQLContextContainerBuilder
-from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
-                                              NSQL_QUERY_PROMPT, QUERY_PROMPT,
-                                              TASK_PROMPT)
+from sidekick.configs.prompt_template import DEBUGGING_PROMPT, NSQL_QUERY_PROMPT, QUERY_PROMPT, TASK_PROMPT
 from sidekick.logger import logger
-from sidekick.utils import (_check_file_info, filter_samples,
-                            load_causal_lm_model, load_embedding_model,
-                            read_sample_pairs, remove_duplicates)
+from sidekick.utils import (
+    _check_file_info,
+    filter_samples,
+    load_causal_lm_model,
+    load_embedding_model,
+    read_sample_pairs,
+    remove_duplicates,
+    offload_state,
+)
 from sqlalchemy import create_engine
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -35,13 +38,17 @@ def __new__(
         data_input_path: str = "./table_info.jsonl",
         sample_queries_path: str = "./samples.csv",
         job_path: str = "./",
-        device: str = "cpu",
+        device: str = "auto",
+        regenerate: bool = False,
     ):
+        offloading = offload_state()
+        if offloading and regenerate:
+            cls._instance = None
+        logger.info(f"Offloading state: {offloading}")
         if cls._instance is None:
             cls._instance = super().__new__(cls)
 
             cls._instance.model, cls._instance.tokenizer = load_causal_lm_model(
-                model_name, cache_path=job_path, device=device
+                model_name, cache_path=f"{job_path}/models/", device=device, off_load=offloading
             )
             model_embed_path = f"{job_path}/models/sentence_transformers"
             device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device
@@ -57,6 +64,7 @@ def __init__(
         sample_queries_path: str = "./samples.csv",
         job_path: str = "./",
         device: str = "cpu",
+        regenerate: bool = False,
     ):
         self.db_url = db_url
         self.engine = create_engine(db_url)
@@ -67,6 +75,7 @@ def __init__(
         self.path = job_path
         self._data_info = None
         self._tasks = None
+        self.model_name = model_name
         self.openai_key = openai_key
         self.content_queries = None

@@ -294,12 +303,13 @@ def generate_sql(
         else:
             if self.model is None:
                 # Load h2oGPT.NSQL if not yet initialized (self.model is None)
-                device = {"": 0} if torch.cuda.is_available() else "cpu"
-                # https://github.com/pytorch/pytorch/issues/52291
-                _load_in_8bit = False if "cpu" in device else True
-                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, device_map=device)
-                self.model = AutoModelForCausalLM.from_pretrained(
-                    self.model_name, device_map=device, load_in_8bit=_load_in_8bit
+                offloading = offload_state()
+                if offloading:
+                    self.clear()
+                logger.info(f"Offloading state: {offloading}")
+                self.model, self.tokenizer = load_causal_lm_model(
+                    self.model_name, cache_path=f"{self.path}/models/", device="auto", off_load=offloading
                 )
 
             # TODO Update needed for multiple tables
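
The lazy-load path above calls self.clear() before reloading when offloading is required. clear() itself is not part of this commit; a plausible implementation (an assumption, not the repository's actual method) would release the resident model so the VRAM is returned to the allocator first:

    import gc

    import torch

    def clear(self):
        # Hypothetical helper: drop the references so Python can free the
        # model/tokenizer, then hand the cached VRAM back to CUDA.
        self.model = None
        self.tokenizer = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()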
@@ -448,6 +458,7 @@ def generate_sql(
                     num_beam_groups=5,
                     num_return_sequences=5,
                     output_scores=True,
+                    do_sample=False,
                     diversity_penalty=1.0,
                     return_dict_in_generate=True,
60 changes: 52 additions & 8 deletions sidekick/utils.py
@@ -13,7 +13,8 @@
 from sentence_transformers import SentenceTransformer
 from sidekick.logger import logger
 from sklearn.metrics.pairwise import cosine_similarity
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from accelerate import init_empty_weights, infer_auto_device_map
 
 
 def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None):
@@ -125,14 +126,14 @@ def save_query(output_path: str, query, response, extracted_entity: Optional[dic
 
 
 def setup_dir(base_path: str):
-    dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache"]
+    dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache", "models/weights"]
     for _dl in dir_list:
         p = Path(f"{base_path}/{_dl}")
         if not p.is_dir():
             p.mkdir(parents=True, exist_ok=True)
 
 
-def update_tables(json_file_path:str, new_data:dict):
+def update_tables(json_file_path: str, new_data: dict):
     # Check if the JSON file exists
     if os.path.exists(json_file_path):
         try:
@@ -225,7 +226,7 @@ def execute_query_pd(query=None, tables_path=None, n_rows=100):
     return res_df
 
 
-def get_table_keys(file_path:str, table_key:str):
+def get_table_keys(file_path: str, table_key: str):
     res = []
     if not os.path.exists(file_path):
         logger.debug(f"File '{file_path}' does not exist.")
@@ -241,20 +242,63 @@ def get_table_keys(file_path:str, table_key:str):
     return res, data
 
 
-def load_causal_lm_model(model_name: str, cache_path: str, device: str, load_in_8bit: bool = True):
+def offload_state():
+    free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
+    total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
+    logger.info(f"Total GPU memory: {total_memory}GB")
+    logger.info(f"Free GPU memory: {free_in_GB}GB")
+    off_load = True
+    if int(free_in_GB) >= int(0.45 * total_memory):
+        off_load = False
+    return off_load
+
+
+def load_causal_lm_model(
+    model_name: str, cache_path: str, device: str, load_in_8bit: bool = True, off_load: bool = False
+):
     try:
         # Load h2oGPT.NSQL model
-        device = {"": 0} if torch.cuda.is_available() else "cpu" if device == "auto" else device
+        free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
+        total_memory = int(torch.cuda.get_device_properties(0).total_memory / 1024**3)
+        n_gpus = torch.cuda.device_count()
+
+        # 22GB (the minimum GPU requirement) is a magic number for the current model size.
+        if off_load and total_memory < 22:
+            # TODO: Evaluate performance when offloading to CPU.
+            max_memory = f"{4}GB"
+            max_memory = {i: max_memory for i in range(n_gpus)}
+            logger.info(f"Max Memory: {max_memory}")
+            config = AutoConfig.from_pretrained(model_name, cache_dir=cache_path, offload_folder=cache_path)
+
+            model = AutoModelForCausalLM.from_config(config)
+            device = infer_auto_device_map(model, max_memory=max_memory)
+            device["lm_head"] = 0
+            _offload_state_dict = True
+            _llm_int8_enable_fp32_cpu_offload = True
+        else:
+            max_memory = f"{int(free_in_GB)-2}GB"
+            max_memory = {i: max_memory for i in range(n_gpus)}
+            _offload_state_dict = False
+            _llm_int8_enable_fp32_cpu_offload = False
+
         if load_in_8bit:
             _load_in_8bit = False if "cpu" in device else True
         else:
             _load_in_8bit = False
-        cache_path = f"{cache_path}/models/"
+        logger.debug(f"Current device config: {device}")
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_path, device_map=device)
         model = AutoModelForCausalLM.from_pretrained(
-            model_name, cache_dir=cache_path, device_map=device, load_in_8bit=_load_in_8bit
+            model_name,
+            cache_dir=cache_path,
+            device_map=device,
+            load_in_8bit=_load_in_8bit,
+            llm_int8_enable_fp32_cpu_offload=_llm_int8_enable_fp32_cpu_offload,
+            offload_state_dict=_offload_state_dict,
+            max_memory=max_memory,
+            offload_folder=f"{cache_path}/weights/",
         )
 
         return model, tokenizer
     except Exception as e:
         logger.info(f"An error occurred while loading the model: {e}")
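Taken together, offload_state() flips offloading on whenever less than 45% of the card's total VRAM is free, and load_causal_lm_model() then caps per-GPU usage, builds an empty model from its config, and lets accelerate's infer_auto_device_map spill the remaining weights to CPU (with overflow going to the on-disk offload_folder). A CUDA-free sketch of just the decision rule, with a worked example (the 0.45 threshold is taken from the diff; the memory numbers are illustrative):

    def should_offload(free_gb: float, total_gb: float, threshold: float = 0.45) -> bool:
        # Mirrors offload_state(): keep everything on GPU only if at least
        # 45% of total VRAM is still free.
        return free_gb < threshold * total_gb

    # Example: a 24GB card with 8GB free -> 8 < 0.45 * 24 = 10.8 -> offload.
    assert should_offload(free_gb=8, total_gb=24)
    # With 12GB free -> 12 >= 10.8 -> the model stays fully on GPU.
    assert not should_offload(free_gb=12, total_gb=24)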
