Skip to content

Commit

Permalink
Workaround to manage context length #4
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Aug 15, 2023
1 parent d38e3c0 commit 02869ed
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
8 changes: 2 additions & 6 deletions sidekick/configs/data_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
}
"""

schema_info_template = {
"Column Name": "",
"Column Type": "",
"Sample Values": []
}
schema_info_template = {"Column Name": "", "Column Type": "", "Sample Values": []}

data_samples_template = "Column {column_name} contains values similar to {comma_separated_sample_values}."
data_samples_template = "'{column_name}' contains values similar to {comma_separated_sample_values}."
26 changes: 25 additions & 1 deletion sidekick/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def generate_sql(
_val = _context.get(_item, None)
if _val:
contextual_context.append(f"{_item}: {_val}")
# Caution:
contextual_context.append(f"Always use: LIMIT 100 with SELECT")

logger.info("Filtering Question/Query pairs ...")
context_queries: list = self.update_context_queries()
Expand Down Expand Up @@ -352,6 +354,7 @@ def generate_sql(

logger.debug(f"Relevant sample column values: {data_samples_list}")
_table_name = ", ".join(table_names)

query = NSQL_QUERY_PROMPT.format(
table_name=_table_name,
column_info=_column_info,
Expand All @@ -365,6 +368,27 @@ def generate_sql(
inputs = self.tokenizer([query], return_tensors="pt")
input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
logger.info(f"Context length: {input_length}")

# Handle limited context length
# Currently, a conservative approach: remove the column description from the prompt if input_length > (2048 - 300)
# Others to try:
# 1. Move to a model with larger context length
# 2. Possibly use a different tokenizer for chunking
# 3. Maybe positional interpolation --> https://arxiv.org/abs/2306.15595
if int(input_length) > 1748:
logger.info("Input length is greater than 1748, removing column description from the prompt")
query = NSQL_QUERY_PROMPT.format(
table_name=_table_name,
column_info=_column_info,
data_info_detailed="",
sample_queries=qna_samples,
context=contextual_context_val,
question_txt=input_question,
)
logger.debug(f"Adjusted query Text:\n {query}")
inputs = self.tokenizer([query], return_tensors="pt")
input_length = 1 if self.model.config.is_encoder_decoder else inputs.input_ids.shape[1]
logger.info(f"Adjusted context length: {input_length}")
# Generate SQL
random_seed = random.randint(0, 50)
torch.manual_seed(random_seed)
Expand All @@ -385,7 +409,7 @@ def generate_sql(
# Below is a precaution in case of an error in the table name during generation
# COLLATE NOCASE is used to ignore case sensitivity, this might be specific to sqlite
_temp = _res.replace("table_name", table_names[0]).split(";")[0]
res = "SELECT" + _temp + " COLLATE NOCASE;"
res = "SELECT" + _temp + ";"

# Validate the generated SQL for parsing errors, along with dialect-specific validation
# Note: Doesn't do well with handling date-time conversions
Expand Down

0 comments on commit 02869ed

Please sign in to comment.