diff --git a/.appignore b/.appignore new file mode 100644 index 0000000..76a5059 --- /dev/null +++ b/.appignore @@ -0,0 +1,17 @@ +.DS_Store +venv/ +.sidekickvenv/ +var/ +.git/ +.idea/ +*/__pycache__/ +scripts/ +setup_cythonize/ +.sh +build/ +dist/ +tests/ +ci/ +examples/sleep_eda/ +examples/telemetry/ +.log diff --git a/.gitignore b/.gitignore index 68bc17f..2fe6b76 100644 --- a/.gitignore +++ b/.gitignore @@ -157,4 +157,12 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ + +# osx generated files +.DS_Store +.DS_Store? +.Trashes +ehthumbs.db +Thumbs.db +.idea diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000..3a390d4 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,5 @@ +{ + "recommendations": [ + "ms-python.python" + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..c2f20fe --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,15 @@ +{ + "[python]": { + "editor.tabSize": 4, + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "files.eol": "\n", + "files.insertFinalNewline": true, + "files.trimFinalNewlines": true, + "files.trimTrailingWhitespace": true, + "python.formatting.provider": "none", + "python.linting.enabled": true, + "python.linting.flake8Enabled": true, + "python.formatting.blackArgs": ["--line-length", "120"], + "python.linting.flake8Args": ["--max-line-length=120"], +} diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..25ae411 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..cf44ca7
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,27 @@
+demo_data = s3cmd get --recursive --skip-existing s3://h2o-sql-sidekick-public/demo/sleepEDA/ ./examples/demo/
+
+.PHONY: download_demo_data
+
+all: download_demo_data
+
+setup: download_demo_data ## Setup
+	python3 -m venv .sidekickvenv
+	./.sidekickvenv/bin/python3 -m pip install --upgrade pip
+	./.sidekickvenv/bin/python3 -m pip install wheel
+	./.sidekickvenv/bin/python3 -m pip install -r requirements.txt
+	mkdir -p ./examples/demo/
+
+
+download_demo_data:
+	mkdir -p ./examples/demo/
+	$(demo_data)
+
+run:
+	./.sidekickvenv/bin/python3 start.py
+
+clean:
+	rm -rf ./db
+	rm -rf ./var
+
+cloud_bundle:
+	h2o bundle -L debug 2>&1 | tee -a h2o-bundle.log
diff --git a/README.md b/README.md
index 2ad6d58..8443994 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,98 @@
 # sql-sidekick
-A simple sql assistant
+A simple SQL assistant (WIP)
+Turn ★ into ⭐ (top-right corner) if you like the project! 🙏
+
+## Motivation
+- Historically, data has commonly been stored in databases; a natural-language interface over them democratizes insight generation.
+- Enable an assistant that helps write complex queries across different database dialects with acceptable execution accuracy (not just string-matching accuracy).
+- Push toward consistent, error-free generation using smaller OSS models to save on compute costs.
+- Provide a toolkit for users to mix and match different model sizes to optimize compute cost - e.g., smaller models for generation, bigger remote models for syntax or spell correction.
+- Build a smart search engine for databases/structured data, with text-to-SQL as a natural language interface (NLI) for data analysis.
+
+## Key Features
+- An interactive UI to capture feedback, along with a Python client and a CLI mode.
+- Automatic DB schema generation for input data using a custom input format.
+- Support for an in-context learning (ICL) pipeline with RAG to control hallucination.
+- Guardrails: checks for SQL injection in SELECT statements, e.g., `SELECT * FROM SleepStudy WHERE user_id = 11 OR 1=1;` (see the sketch after this list).
+- Entity mapping/schema linking: ability to build memory for mapping business context to the data schema dynamically. **Note: currently enabled only via the CLI; other modes are WIP.**
+- Ability to save the chat history of query/answer pairs for future reference and improvement.
+- Self-correction loopback: validates the syntactic correctness of generated SQL. **Note: self-correction is currently enabled for all OpenAI GPT models; WIP for other OSS models.**
+- Integration with different database dialects - currently SQLite, Postgres (_might be broken temporarily_), and Databricks are enabled; support for DuckDB and others is WIP.
+- Debug mode: ability to evaluate/modify and validate a SQL query against the configured database via the UI.
+- Recommendation of sample questions: often, given a dataset, we are unsure what to ask. To get around this problem, we have added the ability to generate recommendations for possible questions.
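+
+Conceptually, the pattern-matching half of the guardrail resembles the minimal sketch below (an illustration under simplifying assumptions, not sql-sidekick's actual implementation): allow only a single SELECT statement and flag always-true tautologies such as `OR 1=1`.
+
+```python
+import re
+
+# Always-true tautology such as "OR 1=1" (grounded in the example above).
+_TAUTOLOGY = re.compile(r"\bor\s+\d+\s*=\s*\d+\b", re.IGNORECASE)
+
+def is_suspicious(sql: str) -> bool:
+    """Return True when the query should be blocked."""
+    statements = [s.strip() for s in sql.strip().split(";") if s.strip()]
+    # Only a single SELECT statement is allowed.
+    if len(statements) != 1 or not statements[0].lower().startswith("select"):
+        return True
+    return bool(_TAUTOLOGY.search(statements[0]))
+
+assert is_suspicious("SELECT * FROM SleepStudy WHERE user_id = 11 OR 1=1;")
+assert not is_suspicious("SELECT * FROM SleepStudy WHERE user_id = 11;")
+```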
+
+# Installation
+## Dev
+```
+1. git clone git@github.com:h2oai/sql-sidekick.git
+2. cd sql-sidekick
+3. make setup
+4. source ./.sidekickvenv/bin/activate
+5. poetry install (if there is an error, try `poetry update` before `poetry install`)
+6. python sidekick/prompter.py
+```
+## Usage
+```
+Dialect: postgres
+- docker pull postgres (pulls the latest version)
+- docker run --rm --name pgsql-dev -e POSTGRES_PASSWORD=abc -p 5432:5432 postgres
+
+Default: sqlite
+Steps:
+- Download and install the .whl --> s3://sql-sidekick/releases/sql_sidekick-0.0.3-py3-none-any.whl
+- python3 -m venv .sidekickvenv
+- source .sidekickvenv/bin/activate
+- python3 -m pip install sql_sidekick-0.0.3-py3-none-any.whl
+```
+## Start
+```
+`sql-sidekick`
+
+Welcome to the SQL Sidekick! I am an AI assistant that helps you with SQL
+queries. I can help you with the following:
+  0. Generate input schema:
+  `sql-sidekick configure generate_schema --data_path "./sample_passenger_statisfaction.csv" --output_path "./table_config.jsonl"`
+
+  1. Configure a local database (for schema validation and syntax checking):
+  `sql-sidekick configure db-setup -t "/table_info.jsonl"` (e.g., format --> https://github.com/h2oai/sql-sidekick/blob/main/examples/telemetry/table_info.jsonl)
+
+  2. Ask a question: `sql-sidekick query -q "avg Gpus" -s "/samples.csv"` (e.g., format --> https://github.com/h2oai/sql-sidekick/blob/main/examples/telemetry/samples.csv)
+
+  3. Learn contextual query/answer pairs: `sql-sidekick learn add-samples` (optional)
+
+  4. Add context as key/value pairs: `sql-sidekick learn update-context` (optional)
+
+Options:
+  --version  Show the version and exit.
+  --help     Show this message and exit.
+
+Commands:
+  configure  Helps in configuring the local database.
+  learn      Helps in learning and building memory.
+  query      Asks a question and returns SQL.
+```
+
+## UI
+### Steps to start locally
+1. Download the wave server [0.26.3](https://github.com/h2oai/wave/releases/tag/v0.26.3)
+2. `tar -xzf wave-0.26.3-linux-amd64.tar.gz`; `./waved -max-request-size="20M"`
+3. Download the latest bundle: https://github.com/h2oai/sql-sidekick/releases/latest
+4. Unzip `ai.h2o.wave.sql-sidekick.x.x.x.wave`
+5. make setup
+6. source ./.sidekickvenv/bin/activate
+7. make run
+
+## Citation & Acknowledgment
+Please consider citing our project if you find it useful:
+
+```bibtex
+@software{sql-sidekick,
+  title = {{sql-sidekick: A simple SQL assistant}},
+  author = {Pramit Choudhary, Michal Malohlava, Narasimha Durgam, Robin Liu, h2o.ai Team},
+  url = {https://github.com/h2oai/sql-sidekick},
+  year = {2024}
+}
+```
+LLM frameworks adopted: [h2ogpt](https://github.com/h2oai/h2ogpt), [h2ogpte](https://pypi.org/project/h2ogpte/), [LangChain](https://github.com/langchain-ai/langchain), [llama_index](https://github.com/run-llama/llama_index), [openai](https://openai.com/blog/openai-api)
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/about.md b/about.md
new file mode 100644
index 0000000..79fbc99
--- /dev/null
+++ b/about.md
@@ -0,0 +1,12 @@
+**App Goal:** Web client for SQL-Sidekick
+
+**Target Audience:** Data (Machine Learning) Scientists, Citizen Data Scientists, Data Engineers, Managers, and Business Analysts
+
+**Actively Being Maintained:** Yes (demo release)
+
+**Last Updated:** January 2024
+
+**Allows uploading and using new model and data:** Yes
+
+**Detailed Description:**
+An experimental demo to evaluate the text-to-SQL capabilities of large language models (LLMs) to enable QnA for structured data.
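The CLI workflow above can also be driven programmatically. Below is a minimal sketch, assuming the `sidekick.prompter` API exactly as it is exercised in `eval/inference.py` and the notebooks later in this diff; the paths, table name, database credentials, and question are placeholders.

```python
from sidekick.prompter import ask, db_setup
from sidekick.schema_generator import generate_schema
from sidekick.utils import setup_dir

# Placeholder paths/names; model endpoint env vars (OPENAI_API_KEY or
# H2OGPT_URL/H2OGPT_API_TOKEN, H2OGPTE_URL/H2OGPTE_API_TOKEN) are assumed
# to be set as shown in the example notebooks.
base_path = "."
data_path = "./my_data.csv"
table_name = "my_table"
setup_dir(base_path)

# 1. Generate an input schema file for the CSV.
_, table_info_path = generate_schema(
    data_path=data_path,
    output_path=f"{base_path}/var/lib/tmp/{table_name}_table_info.jsonl",
)

# 2. Create a local SQLite database and load the sample rows.
_, err = db_setup(
    db_name="query_db",
    hostname="localhost",
    user_name="sqlite",
    password="abc",
    port="5432",
    table_info_path=table_info_path,
    table_samples_path=data_path,
    table_name=table_name,
    local_base_path=base_path,
)

# 3. Ask a question; the generated SQL is validated and executed locally.
result, _, error = ask(
    question="What is the average value per category?",
    table_info_path=table_info_path,
    sample_queries_path=None,
    table_name=table_name,
    is_command=False,
    model_name="h2ogpt-sql-sqlcoder-34b-alpha",
    is_regenerate=False,
    is_regen_with_options=False,
    execute_query=True,
    guardrails=True,
    local_base_path=base_path,
    debug_mode=False,
    self_correction=True,
)
```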
diff --git a/app.toml b/app.toml new file mode 100644 index 0000000..c59edcf --- /dev/null +++ b/app.toml @@ -0,0 +1,49 @@ +[App] +Name = "ai.h2o.wave.sql-sidekick" +Title = "SQL-Sidekick" +Description = "QnA with tabular data using NLQ" +LongDescription = "about.md" +InstanceLifecycle = "MANAGED" +Tags = ["DATA_SCIENCE", "MACHINE_LEARNING", "NLP", "GENERATIVE_AI"] +Version = "0.2.2" + +[Runtime] +MemoryLimit = "64Gi" +MemoryReservation = "64Gi" +module = "start" +VolumeMount = "/meta_data" +VolumeSize = "100Gi" +ResourceVolumeSize = "64Gi" +GPUCount = 2 +RuntimeVersion = "ub2004_cuda114_cudnn8_py38_wlatest_a10g" +RoutingMode = "BASE_URL" +EnableOIDC = true + +[[Env]] +Name = "H2O_WAVE_MAX_REQUEST_SIZE" +Value = "20M" + +[[Env]] +Name = "HEAP_ID" +Secret = "heap-analytics" +SecretKey = "id" + +[[Env]] +Name = "H2OGPT_URL" +Secret = "h2ogpt-oss-sqlsidekick" +SecretKey = "h2ogpt-url" + +[[Env]] +Name = "H2OGPT_API_TOKEN" +Secret = "h2ogpt-oss-sqlsidekick" +SecretKey = "h2ogpt-key" + +[[Env]] +Name = "H2OGPTE_URL" +Secret = "h2ogpte-sqlsidekick" +SecretKey = "h2ogpte-url" + +[[Env]] +Name = "H2OGPTE_API_TOKEN" +Secret = "h2ogpte-sqlsidekick" +SecretKey = "h2ogpte-key" diff --git a/eval/README.md b/eval/README.md new file mode 100644 index 0000000..1cb6fa1 --- /dev/null +++ b/eval/README.md @@ -0,0 +1,20 @@ +Usage: +1. `python3 -m venv .sidekick_venv` +2. `source .sidekick_venv/bin/activate` +3. `pip install --force-reinstall sql_sidekick-x.x.x-py3-none-any.whl` (# replace x.x.x with the latest version number), https://github.com/h2oai/sql-sidekick/releases +4. `python eval/inference.py run-eval --help` +``` +Options: + -i, --input_data_path TEXT Path to dataset in .csv format + -t, --table_name TEXT Table name related to the supplied dataset + -e, --eval_data_path TEXT Path to eval dataset in .csv format + -m, --model_name TEXT Model name to use for inference + -s, --sample_qna_path TEXT Path to sample QnA in .csv format + -n, --iterations INTEGER Number of iterations to run + -th, --threshold FLOAT Similarity threshold + -k, --kwargs TEXT Additional arguments +``` +5. `python eval/inference.py run-eval -i -t "your_table_name" -e -s -m "h2ogpt-sql-sqlcoder-34b-alpha"` + + +Benchmarks: WIP diff --git a/eval/inference.py b/eval/inference.py new file mode 100644 index 0000000..ac767e0 --- /dev/null +++ b/eval/inference.py @@ -0,0 +1,116 @@ +# 1. python3 -m venv .sidekick_venv +# 2. source .sidekick_venv/bin/activate +# 3. pip install --force-reinstall sql_sidekick-x.x.x-py3-none-any.whl (# replace x.x.x with the latest version number) + +import os +from pathlib import Path +from typing import Optional + +import click +import pandas as pd +from sidekick.prompter import ask, db_setup +from sidekick.query import SQLGenerator +from sidekick.schema_generator import generate_schema +from sidekick.utils import generate_text_embeddings, setup_dir +from sklearn.metrics.pairwise import cosine_similarity + +base_path = (Path(__file__).parent / "../").resolve() +base_path = "." 
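+# NOTE: the repo-relative base_path resolved above is overridden with the
+# current working directory, so all cached artifacts (./var, ./db, ./models)
+# are created relative to wherever the script is launched from.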
+cache_path = f"{base_path}/var/lib/tmp"
+setup_dir(base_path)
+
+HOST_NAME = "localhost"
+USER_NAME = "sqlite"
+PASSWORD = "abc"
+DB_NAME = "query_eval_db"
+PORT = "5432"
+
+# Initialize DB: start from a clean SQLite file on every run
+if Path(f"{base_path}/db/sqlite/{DB_NAME}.db").exists():
+    os.remove(f"{base_path}/db/sqlite/{DB_NAME}.db")
+
+
+def compute_similarity_score(x1: str, x2: str):
+    # Embeds [x1, x2] and compares both embeddings against x2's embedding;
+    # callers read the [0][0] entry, i.e., cosine(x1, x2).
+    m_path = f"{base_path}/models/sentence_transformers/"
+    _embedding1 = generate_text_embeddings(m_path, x=[x1, x2])
+    _embedding2 = generate_text_embeddings(m_path, x=[x2])
+    similarities_score = cosine_similarity(_embedding1.astype(float), _embedding2.astype(float))
+    return similarities_score
+
+
+@click.group()
+@click.version_option()
+def cli():
+    """For benchmarking SQL-Sidekick."""
+
+
+@cli.command()
+@click.option("--input_data_path", "-i", help="Path to dataset in .csv format")
+@click.option("--table_name", "-t", help="Table name related to the supplied dataset")
+@click.option("--eval_data_path", "-e", help="Path to eval dataset in .csv format")
+@click.option("--model_name", "-m", default="h2ogpt-sql-sqlcoder-34b-alpha", help="Model name to use for inference")
+@click.option("--sample_qna_path", "-s", default=None, help="Path to sample QnA in .csv format")
+@click.option("--iterations", "-n", default=1, help="Number of iterations to run")
+@click.option("--threshold", "-th", default=0.9, help="Similarity threshold")
+@click.option("--kwargs", "-k", default=None, help="Additional arguments")
+def run_eval(input_data_path: str, table_name: str, eval_data_path: str, model_name: str, iterations: int, threshold: float, sample_qna_path: Optional[str] = None, **kwargs):
+    # Generate schema for the supplied input data
+    _, table_info_path = generate_schema(data_path=input_data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl")
+    # DB setup
+    _, err = db_setup(
+        db_name=DB_NAME,
+        hostname=HOST_NAME,
+        user_name=USER_NAME,
+        password=PASSWORD,
+        port=PORT,
+        table_info_path=table_info_path,
+        table_samples_path=input_data_path,
+        table_name=table_name,
+        local_base_path=base_path
+    )
+
+    # Read the gold question/answer pairs and score each generated query
+    syntax_accuracy = {}
+    failures = {}
+    compare_df = pd.read_csv(eval_data_path)
+    count = 0
+    while count < iterations:
+        for _row in compare_df.itertuples():
+            input_q = _row.question
+            expected_sql = _row.answer
+
+            # With self-correction
+            _generated_sql = ''
+            result, _, _ = ask(
+                question=input_q,
+                table_info_path=table_info_path,
+                sample_queries_path=sample_qna_path,
+                table_name=table_name,
+                is_command=False,
+                model_name=model_name,
+                is_regenerate=False,
+                is_regen_with_options=False,
+                execute_query=True,
+                local_base_path=base_path,
+                debug_mode=False,
+                guardrails=False,
+                self_correction=True
+            )
+
+            if result and len(result) > 0:
+                _idx = [result.index(_r) for _r in result if _r.startswith("``` sql")]
+                if _idx:
+                    _generated_sql = str(result[_idx[0]].split("``` sql\n")[1]).replace("```", "").strip()
+                    split_text = _generated_sql.split("\n")
+                    _generated_sql = " ".join(split_text).strip()
+                _syntax_score = compute_similarity_score(expected_sql, _generated_sql)
+
+                if _syntax_score[0][0] > threshold:
+                    syntax_accuracy[input_q] = _syntax_score[0][0]
+                else:
+                    failures[input_q] = (expected_sql, _generated_sql)
+        count += 1  # one full pass over the eval set per iteration
+    print(f"Syntax accuracy: {float(len(syntax_accuracy) / compare_df.shape[0])}")
+    print(f"Failure cases: {failures}")
+
+
+if __name__ == "__main__":
+    cli()
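A quick, hypothetical sanity check for `compute_similarity_score`: two near-identical queries should score close to 1.0, comfortably above the default `--threshold` of 0.9.

```python
# Hypothetical usage; run from the eval/ directory so that inference.py is
# importable, and assumes a sentence-transformers model has been downloaded
# under ./models/sentence_transformers/.
from inference import compute_similarity_score

score = compute_similarity_score(
    "SELECT AVG(Sleep_Duration) FROM sleep_health_and_lifestyle;",
    "select avg(sleep_duration) from sleep_health_and_lifestyle;",
)
print(score[0][0])  # expected to be close to 1.0 for near-identical SQL
```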
diff --git a/examples/notebooks/Guardrails_SQL_injection.ipynb b/examples/notebooks/Guardrails_SQL_injection.ipynb
new file mode 100644
index 0000000..5bc921f
--- /dev/null
+++ b/examples/notebooks/Guardrails_SQL_injection.ipynb
@@ -0,0 +1,677 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "8cc48531-3715-4285-b941-6c5c400cc399",
+   "metadata": {},
+   "source": [
+    "**Guardrails** are important to prevent misuse and mitigate concerns about bad actors. Within sql-sidekick, we have added support for scanning generated SELECT statements for malicious patterns that could destroy the database or compromise sensitive information.\n",
+    "We adopted a 2-step approach, using pattern matching and LLM prompting, to ensure safety and compliance in a product setting. The self-corrective loop with external feedback helps maintain syntactical integrity.\n",
+    "\n",
+    "_References:_\n",
+    "1. https://github.com/payloadbox/sql-injection-payload-list#generic-sql-injection-payloads\n",
+    "2. https://www.invicti.com/blog/web-security/sql-injection-cheat-sheet/#InlineSamples"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ed26aaa3-ec95-4ef9-9a93-05f3c629fe90",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# !pip install --force-reinstall sql_sidekick-0.x.x-py3-none-any.whl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "b336798f-43c6-4ad8-96ba-0cd43ccd8c9d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "337f3cf7-a366-4a2e-9746-3a0b8c5e5a1a",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'0.2.0'"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from sidekick import prompter\n",
+    "prompter.__version__"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "0a55a6ac-65c3-47e3-ab3e-0818800ba437",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import pandas as pd\n",
+    "from pathlib import Path\n",
+    "\n",
+    "\n",
+    "from sidekick.prompter import ask, db_setup\n",
+    "from sidekick.query import SQLGenerator\n",
+    "from sidekick.schema_generator import generate_schema\n",
+    "from sidekick.utils import setup_dir\n",
+    "\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "78cd0f95-5366-4b3e-b82a-3d5162ab218d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# env variables\n",
+    "\n",
+    "os.environ['OPENAI_API_KEY'] = \"\"\n",
+    "os.environ['H2OGPT_URL'] = ''\n",
+    "os.environ['H2OGPT_API_TOKEN'] = \"\"\n",
+    "os.environ['H2OGPTE_URL'] = \"https://h2ogpte.genai.h2o.ai\" # e.g.
https://<>.h2ogpte.h2o.ai\n", + "os.environ['H2OGPTE_API_TOKEN'] = \"\"\n", + "\n", + "\n", + "base_path = \".\"\n", + "cache_path = f\"{base_path}/var/lib/tmp\"\n", + "setup_dir(base_path)\n", + "\n", + "HOST_NAME = \"localhost\"\n", + "USER_NAME = \"sqlite\"\n", + "PASSWORD = \"abc\"\n", + "DB_NAME = \"query_test\"\n", + "PORT = \"5432\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "93071d24-acfa-45e2-9f96-14c6ba87fa82", + "metadata": {}, + "outputs": [], + "source": [ + "data_path = \"./sleep_health_and_lifestyle_dataset.csv\"\n", + "# Replace table_name as needed\n", + "table_name = \"sleep_health_and_lifestyle\"\n", + "table_schema, table_info_path = generate_schema(data_path=data_path, output_path=f\"{cache_path}/{table_name}_table_info.jsonl\")\n", + "\n", + "# Initialize DB\n", + "if Path(f\"{base_path}/db/sqlite/{DB_NAME}.db\").exists():\n", + " os.remove(f\"{base_path}/db/sqlite/{DB_NAME}.db\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "67ea24ce-c170-4977-beb3-f7471c90003c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Information supplied:\n", + " query_test, localhost, sqlite, abc, 5432\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 22:24:26.138\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mcreate_db\u001b[0m:\u001b[36m108\u001b[0m - \u001b[34m\u001b[1mSQLite DB is created successfully.\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Database created successfully!\n", + "Table name: sleep_health_and_lifestyle\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 22:24:26.142\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36m_extract_schema_info\u001b[0m:\u001b[36m162\u001b[0m - \u001b[34m\u001b[1mUsing schema information from: ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.144\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mcreate_table\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mSchema info used for creating table:\n", + " Person_ID NUMERIC,\n", + "Gender TEXT COLLATE NOCASE,\n", + "Age NUMERIC,\n", + "Occupation TEXT COLLATE NOCASE,\n", + "Sleep_Duration NUMERIC,\n", + "Quality_of_Sleep NUMERIC,\n", + "Physical_Activity_Level NUMERIC,\n", + "Stress_Level NUMERIC,\n", + "BMI_Category TEXT COLLATE NOCASE,\n", + "Blood_Pressure TEXT COLLATE NOCASE,\n", + "Heart_Rate NUMERIC,\n", + "Daily_Steps NUMERIC,\n", + "Sleep_Disorder TEXT COLLATE NOCASE\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.155\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mcreate_table\u001b[0m:\u001b[36m198\u001b[0m - \u001b[1mTable created: sleep_health_and_lifestyle\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checked table sleep_health_and_lifestyle exists in the DB.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 22:24:26.159\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m222\u001b[0m - \u001b[34m\u001b[1mAdding sample values to table: ./sleep_health_and_lifestyle_dataset.csv\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.164\u001b[0m | 
\u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m228\u001b[0m - \u001b[34m\u001b[1mInserting chunk: 0\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.264\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m233\u001b[0m - \u001b[1mData inserted into table: sleep_health_and_lifestyle\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.267\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m238\u001b[0m - \u001b[1mNumber of rows inserted: 374\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created a Database query_test. Inserted sample values from ./sleep_health_and_lifestyle_dataset.csv into table sleep_health_and_lifestyle, please ask questions!\n" + ] + } + ], + "source": [ + "# Set add_sample=False if no need to add rows to the table\n", + "response , err = db_setup(\n", + " db_name=DB_NAME,\n", + " hostname=HOST_NAME,\n", + " user_name=USER_NAME,\n", + " password=PASSWORD,\n", + " port=PORT,\n", + " table_info_path=table_info_path,\n", + " table_samples_path=data_path,\n", + " table_name=table_name,\n", + " local_base_path=base_path,\n", + " add_sample=True\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "1760df87-0a90-48be-a226-984d082a2481", + "metadata": {}, + "source": [ + "## Example 1:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b1349eb7-4ba8-4a44-9956-02ed3b1b4c63", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 22:24:26.321\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m490\u001b[0m - \u001b[1mTable in use: ['sleep_health_and_lifestyle']\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.322\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m491\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.323\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m524\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.324\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m525\u001b[0m - \u001b[1mQuestion: Execute SQL:\n", + "\n", + " SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.325\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m543\u001b[0m - \u001b[34m\u001b[1mTable info path: ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.326\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m353\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:26.327\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m357\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.846\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m360\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-19 
22:24:27.848\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m361\u001b[0m - \u001b[1mFree GPU memory: 21GB\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.860\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m467\u001b[0m - \u001b[1mUsing information info from path ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.861\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mExecuting user provided SQL without generation...\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.862\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m598\u001b[0m - \u001b[1mInput query: Execute SQL:\n", + "\n", + " SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.863\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m599\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "select * from sleep_health_and_lifestyle where person_id = 105;\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.872\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m625\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "None\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.874\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mexecute_query\u001b[0m:\u001b[36m269\u001b[0m - \u001b[34m\u001b[1mExecuting query:\n", + " select * from sleep_health_and_lifestyle where person_id = 105;\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The results are:\n", + " [(105, 'Female', 36, 'Teacher', 7.2, 8, 60, 4, 'Normal', '115/75', 68, 7000, 'Sleep Apnea')]\n", + "Exiting...\n" + ] + } + ], + "source": [ + "# Example lets consider a simple query as mentioned below,\n", + "# Expected output: a single row having information about person_id=105\n", + "\n", + "input_q = \"\"\"\n", + " SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105;\n", + " \"\"\"\n", + "\n", + "# We are running in Debug Mode as we are validating SQL statements directly `debug_mode=True`\n", + "result = None\n", + "question = f\"Execute SQL:\\n{input_q}\"\n", + "\n", + "result, _, error = ask(\n", + " question=question,\n", + " table_info_path=table_info_path,\n", + " sample_queries_path=None,\n", + " table_name=table_name,\n", + " is_command=False,\n", + " model_name=\"h2ogpt-sql-nsql-llama-2-7B\",\n", + " is_regenerate=False,\n", + " is_regen_with_options=False,\n", + " execute_query=True,\n", + " guardrails=False,\n", + " local_base_path=base_path,\n", + " debug_mode=True,\n", + " self_correction=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "3bbbd22f-a65e-4864-8ff6-716ace096e4a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 22:24:27.939\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m490\u001b[0m - \u001b[1mTable in use: ['sleep_health_and_lifestyle']\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.941\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m491\u001b[0m - \u001b[1mSQL 
dialect for generation: sqlite\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.942\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m524\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.942\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m525\u001b[0m - \u001b[1mQuestion: Execute SQL:\n", + "\n", + " SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105 OR 1=1;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.943\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m543\u001b[0m - \u001b[34m\u001b[1mTable info path: ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.944\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m353\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.945\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m357\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.946\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m360\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.947\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m361\u001b[0m - \u001b[1mFree GPU memory: 21GB\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.957\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m467\u001b[0m - \u001b[1mUsing information info from path ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.958\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mExecuting user provided SQL without generation...\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.959\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m598\u001b[0m - \u001b[1mInput query: Execute SQL:\n", + "\n", + " SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105 OR 1=1;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.960\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m599\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "select * from sleep_health_and_lifestyle where person_id = 105 or 1=1;\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.962\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m625\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "None\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:27.964\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mexecute_query\u001b[0m:\u001b[36m269\u001b[0m - \u001b[34m\u001b[1mExecuting query:\n", + " select * from sleep_health_and_lifestyle where person_id = 105 or 1=1;\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Preview of the result:\n", + " Person_ID Gender Age Occupation Sleep_Duration Quality_of_Sleep \\\n", + "0 1 Male 27 Software Engineer 6.1 6 \n", + "1 2 Male 28 
Doctor 6.2 6 \n",
+      "\n",
+      " Physical_Activity_Level Stress_Level BMI_Category Blood_Pressure \\\n",
+      "0 42 6 Overweight 126/83 \n",
+      "1 60 8 Normal 125/80 \n",
+      "\n",
+      " Heart_Rate Daily_Steps Sleep_Disorder \n",
+      "0 77 4200 None \n",
+      "1 75 10000 None \n",
+      "Exiting...\n",
+      "Number of rows returned: (101, 13)\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Injecting malicious tokens\n",
+    "# Example: OR 1=1\n",
+    "# The user input below is a syntactically valid SQL statement, but since `OR 1=1` is always true,\n",
+    "# instead of returning information about person_id=105, a bad actor might get information about all users.\n",
+    "\n",
+    "# Built-in guardrails are disabled\n",
+    "# The result contains more than one row\n",
+    "\n",
+    "result = None\n",
+    "input_q = \"\"\"\n",
+    "    SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105 OR 1=1;\n",
+    "    \"\"\"\n",
+    "question = f\"Execute SQL:\\n{input_q}\"\n",
+    "# 1. Self-correction is disabled\n",
+    "result, _, error = ask(\n",
+    "    question=question,\n",
+    "    table_info_path=table_info_path,\n",
+    "    sample_queries_path=None,\n",
+    "    table_name=table_name,\n",
+    "    is_command=False,\n",
+    "    model_name=\"h2ogpt-sql-nsql-llama-2-7B\",\n",
+    "    is_regenerate=False,\n",
+    "    is_regen_with_options=False,\n",
+    "    execute_query=True,\n",
+    "    guardrails=False,\n",
+    "    local_base_path=base_path,\n",
+    "    debug_mode=True,\n",
+    "    self_correction=False\n",
+    ")\n",
+    "\n",
+    "# format the final output\n",
+    "import ast\n",
+    "\n",
+    "res = ast.literal_eval(result[4])\n",
+    "final_res = pd.DataFrame(res)\n",
+    "\n",
+    "print(f\"Number of rows returned: {final_res.shape}\") "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "72112ca3-4894-40d0-8115-9aef1e98205b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2024-01-19 22:24:28.043\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m490\u001b[0m - \u001b[1mTable in use: ['sleep_health_and_lifestyle']\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:28.044\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m491\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:28.046\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m524\u001b[0m - \u001b[1mSetting context...\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:28.047\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m525\u001b[0m - \u001b[1mQuestion: Execute SQL:\n",
+      "\n",
+      "    SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105 OR 1=1;\n",
+      "    \u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:28.048\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m543\u001b[0m - \u001b[34m\u001b[1mTable info path: ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:28.049\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m353\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:28.051\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m357\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:28.052\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m360\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.053\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m361\u001b[0m - \u001b[1mFree GPU memory: 21GB\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.062\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m467\u001b[0m - \u001b[1mUsing information info from path ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.063\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mExecuting user provided SQL without generation...\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.064\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m598\u001b[0m - \u001b[1mInput query: Execute SQL:\n", + "\n", + " SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105 OR 1=1;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.065\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m599\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "select * from sleep_health_and_lifestyle where person_id = 105 or 1=1;\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.067\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m625\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "None\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.068\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m638\u001b[0m - \u001b[1mChecking for vulnerabilities in the provided SQL: select * from sleep_health_and_lifestyle where person_id = 105 or 1=1;\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:28.075\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mcheck_vulnerability\u001b[0m:\u001b[36m548\u001b[0m - \u001b[34m\u001b[1mRequesting additional scan using configured models\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: Server version 1.3.0-dev5 doesn't match client version 1.2.7: unexpected errors may occur.\n", + "Please install the correct version of H2OGPTE with `pip install h2ogpte==1.3.0-dev5`.\n", + "You can enable strict version checking by passing strict_version_check=True.\n", + "The results are:\n", + " The input question has malicious patterns, **('or', '', ''), or 1=1** that could lead to SQL Injection.\n", + "Sorry, I will not be able to provide an answer.\n", + "Please try rephrasing the question. The query is vulnerable to SQL injection. The use of the OR clause with a constant value (1=1) is a common indicator of an injection attack. 
This suggests that an attacker may be able to manipulate the query to execute arbitrary SQL code, potentially leading to unauthorized access or modification of data.\n",
+      "Exiting...\n"
+     ]
+    }
+   ],
+   "source": [
+    "# With guardrails 'ON', we are able to prevent such malicious intent\n",
+    "# Guardrails can be toggled by setting `guardrails=True/False`\n",
+    "\n",
+    "result = None\n",
+    "input_q = \"\"\"\n",
+    "    SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105 OR 1=1;\n",
+    "    \"\"\"\n",
+    "question = f\"Execute SQL:\\n{input_q}\"\n",
+    "\n",
+    "result, _, error = ask(\n",
+    "    question=question,\n",
+    "    table_info_path=table_info_path,\n",
+    "    sample_queries_path=None,\n",
+    "    table_name=table_name,\n",
+    "    is_command=False,\n",
+    "    model_name=\"h2ogpt-sql-nsql-llama-2-7B\",\n",
+    "    is_regenerate=False,\n",
+    "    is_regen_with_options=False,\n",
+    "    execute_query=True,\n",
+    "    guardrails=True,\n",
+    "    local_base_path=base_path,\n",
+    "    debug_mode=True,\n",
+    "    self_correction=False\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "62492d4f-72e8-4bb6-8d17-520dc0645b1a",
+   "metadata": {},
+   "source": [
+    "## Example 2:\n",
+    "Classic Inline Comment SQL Injection Attack"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "2fd7f05d-ce44-4b45-90c5-4c077dddacdf",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2024-01-19 22:24:37.276\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m490\u001b[0m - \u001b[1mTable in use: ['sleep_health_and_lifestyle']\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.277\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m491\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.278\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m524\u001b[0m - \u001b[1mSetting context...\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.279\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m525\u001b[0m - \u001b[1mQuestion: Execute SQL:\n",
+      "\n",
+      "    10; DROP TABLE sleep_health_and_lifestyle /*;\n",
+      "    \u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.280\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m543\u001b[0m - \u001b[34m\u001b[1mTable info path: ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.281\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m353\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.282\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m357\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.283\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m360\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.284\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m361\u001b[0m - \u001b[1mFree GPU memory: 21GB\u001b[0m\n",
+      "\u001b[32m2024-01-19 22:24:37.294\u001b[0m | \u001b[1mINFO \u001b[0m
| \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m467\u001b[0m - \u001b[1mUsing information info from path ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:37.295\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mExecuting user provided SQL without generation...\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:37.296\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m598\u001b[0m - \u001b[1mInput query: Execute SQL:\n", + "\n", + " 10; DROP TABLE sleep_health_and_lifestyle /*;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:24:37.298\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m599\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "10; drop table sleep_health_and_lifestyle /*;\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:37.300\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m625\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "None\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:37.302\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m638\u001b[0m - \u001b[1mChecking for vulnerabilities in the provided SQL: 10; drop table sleep_health_and_lifestyle /*;\u001b[0m\n", + "\u001b[32m2024-01-19 22:24:37.303\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mcheck_vulnerability\u001b[0m:\u001b[36m548\u001b[0m - \u001b[34m\u001b[1mRequesting additional scan using configured models\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: Server version 1.3.0-dev5 doesn't match client version 1.2.7: unexpected errors may occur.\n", + "Please install the correct version of H2OGPTE with `pip install h2ogpte==1.3.0-dev5`.\n", + "You can enable strict version checking by passing strict_version_check=True.\n", + "The results are:\n", + " The input question has malicious patterns, **SQL keywords does not start with SELECT, only SELECT queries are allowed.** that could lead to SQL Injection.\n", + "Sorry, I will not be able to provide an answer.\n", + "Please try rephrasing the question.\n", + "Exiting...\n" + ] + } + ], + "source": [ + "result = None\n", + "input_q = \"\"\"\n", + " 10; DROP TABLE sleep_health_and_lifestyle /*;\n", + " \"\"\"\n", + "question = f\"Execute SQL:\\n{input_q}\"\n", + "\n", + "result, _, error = ask(\n", + " question=question,\n", + " table_info_path=table_info_path,\n", + " sample_queries_path=None,\n", + " table_name=table_name,\n", + " is_command=False,\n", + " model_name=\"h2ogpt-sql-nsql-llama-2-7B\",\n", + " is_regenerate=False,\n", + " is_regen_with_options=False,\n", + " execute_query=True,\n", + " guardrails=True,\n", + " local_base_path=base_path,\n", + " debug_mode=True,\n", + " self_correction=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3d3c8150-7bb7-45bf-84e7-99a5d4607817", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 22:32:45.505\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m490\u001b[0m - \u001b[1mTable in use: ['sleep_health_and_lifestyle']\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.506\u001b[0m | \u001b[1mINFO 
\u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m491\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.508\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m524\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.509\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m525\u001b[0m - \u001b[1mQuestion: Execute SQL:\n", + "\n", + " SELECT /*!32302 1/0, */ 1 FROM sleep_health_and_lifestyle;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.509\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m543\u001b[0m - \u001b[34m\u001b[1mTable info path: ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.510\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m353\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.511\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m357\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.512\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m360\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.513\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m361\u001b[0m - \u001b[1mFree GPU memory: 21GB\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.524\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m467\u001b[0m - \u001b[1mUsing information info from path ./var/lib/tmp/sleep_health_and_lifestyle_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.525\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mExecuting user provided SQL without generation...\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.526\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m598\u001b[0m - \u001b[1mInput query: Execute SQL:\n", + "\n", + " SELECT /*!32302 1/0, */ 1 FROM sleep_health_and_lifestyle;\n", + " \u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.527\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m599\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "select /*!32302 1/0, */ 1 from sleep_health_and_lifestyle;\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.530\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m625\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "None\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.531\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m638\u001b[0m - \u001b[1mChecking for vulnerabilities in the provided SQL: select /*!32302 1/0, */ 1 from sleep_health_and_lifestyle;\u001b[0m\n", + "\u001b[32m2024-01-19 22:32:45.533\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | 
\u001b[36msidekick.utils\u001b[0m:\u001b[36mcheck_vulnerability\u001b[0m:\u001b[36m548\u001b[0m - \u001b[34m\u001b[1mRequesting additional scan using configured models\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: Server version 1.3.0-dev5 doesn't match client version 1.2.7: unexpected errors may occur.\n", + "Please install the correct version of H2OGPTE with `pip install h2ogpte==1.3.0-dev5`.\n", + "You can enable strict version checking by passing strict_version_check=True.\n", + "The results are:\n", + " The input question has malicious patterns, **/*!32302 1/0, */** that could lead to SQL Injection.\n", + "Sorry, I will not be able to provide an answer.\n", + "Please try rephrasing the question. The query contains a comment with a SQL injection vulnerability. The comment `/*!32302 1/0, */` is not properly formatted and could be used to inject malicious SQL code.\n", + "Exiting...\n" + ] + } + ], + "source": [ + "result = None\n", + "input_q = \"\"\"\n", + " SELECT /*!32302 1/0, */ 1 FROM sleep_health_and_lifestyle;\n", + " \"\"\"\n", + "question = f\"Execute SQL:\\n{input_q}\"\n", + "\n", + "result, _, error = ask(\n", + " question=question,\n", + " table_info_path=table_info_path,\n", + " sample_queries_path=None,\n", + " table_name=table_name,\n", + " is_command=False,\n", + " model_name=\"h2ogpt-sql-nsql-llama-2-7B\",\n", + " is_regenerate=False,\n", + " is_regen_with_options=False,\n", + " execute_query=True,\n", + " guardrails=True,\n", + " local_base_path=base_path,\n", + " debug_mode=True,\n", + " self_correction=False\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5260a38a-8fa1-4ac1-96e3-b1108cb1eee9", + "metadata": {}, + "source": [ + "**Note:**\n", + "\n", + "The examples shown above are not an exhaustive list; this is a work in progress.\n", + "We have different levels of guardrails enabled for safe execution, and we are continually updating and improving them as we learn more.\n", + "Don't hesitate to reach out."
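The two runs above exercise the guardrails against a stacked-statement payload (`10; DROP TABLE ... /*`) and a MySQL conditional-comment payload (`/*!32302 1/0, */`). To make the idea concrete, here is a minimal sketch of this style of pattern-based pre-check; the `SUSPICIOUS_PATTERNS` list and the `looks_malicious` helper are illustrative assumptions, not sidekick's actual `check_vulnerability`, which (per the DEBUG log above) additionally requests a scan from the configured models.

```python
import re

# Illustrative patterns only (an assumption, not sidekick's actual rule set);
# a production scanner would be considerably more thorough.
SUSPICIOUS_PATTERNS = [
    r";\s*\S",   # stacked statements, e.g. "10; DROP TABLE ... /*"
    r"/\*!\d+",  # MySQL conditional comments, e.g. "/*!32302 1/0, */"
    r"\b(drop|alter|delete|insert|update|truncate)\b",
]

def looks_malicious(sql: str) -> bool:
    """Flag SQL that is not a plain SELECT or matches a known injection pattern."""
    stripped = sql.strip().lower()
    if not stripped.startswith("select"):
        return True  # only SELECT queries are allowed
    return any(re.search(pattern, stripped) for pattern in SUSPICIOUS_PATTERNS)

assert looks_malicious("10; drop table sleep_health_and_lifestyle /*;")
assert looks_malicious("select /*!32302 1/0, */ 1 from sleep_health_and_lifestyle;")
assert not looks_malicious("SELECT AVG(Sleep_Duration) FROM sleep_health_and_lifestyle")
```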
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/notebooks/databricks_db.ipynb b/examples/notebooks/databricks_db.ipynb new file mode 100644 index 0000000..f48602f --- /dev/null +++ b/examples/notebooks/databricks_db.ipynb @@ -0,0 +1,1178 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "b69aaf22-a0c1-4f9d-b5b1-2ca0aa7130db", + "metadata": {}, + "outputs": [], + "source": [ + "# https://github.com/h2oai/sql-sidekick/releases\n", + "#!pip install --force-reinstall sql_sidekick-0.x.x-py3-none-any.whl" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "746d5a53-4ca8-445d-b795-4f79c77dffcb", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "032a240c-5e9b-4982-b6e9-398ee6ce001e", + "metadata": {}, + "outputs": [], + "source": [ + "from sidekick import prompter" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "5173aef9-0e2d-4789-8c16-eff0c043c747", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'0.1.9'" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompter.__version__" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "4c92083f-c2d3-4bb9-a22d-479b4f40695a", + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import sys\n", + "\n", + "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", + "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "023038c9-5758-45ec-a11d-99a64fe423b7", + "metadata": {}, + "outputs": [], + "source": [ + "from sidekick.prompter import db_setup, ask\n", + "from sidekick.schema_generator import generate_schema\n", + "from sidekick.utils import setup_dir, list_db_dialects" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "4bcb8e59-cd09-4f82-8606-0ad481953c04", + "metadata": {}, + "outputs": [], + "source": [ + "base_path = \"./\"\n", + "cache_path = f\"{base_path}/var/lib/tmp\"\n", + "setup_dir(base_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "57217aea-cf04-4710-9784-3fd1eba1970e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['bigquery',\n", + " 'clickhouse',\n", + " 'duckdb',\n", + " 'hive',\n", + " 'mysql',\n", + " 'oracle',\n", + " 'postgres',\n", + " 'presto',\n", + " 'redshift',\n", + " 'snowflake',\n", + " 'spark',\n", + " 'spark2',\n", + " 'sqlite',\n", + " 'starrocks',\n", + " 'tableau',\n", + " 'trino',\n", + " 'tsql',\n", + " 'databricks',\n", + " 'drill',\n", + " 'teradata']" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list_db_dialects()" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "304da340-1e5f-4562-a9b8-f0bd9dee59ff", + "metadata": {}, + "outputs": [], + "source": [ + "# env variables\n", + "import os\n", + "\n", + "os.environ['DATABRICKS_HOST'] = \"<>.cloud.databricks.com\" 
# e.g. <>.cloud.databricks.com\n", + "os.environ['DATABRICKS_CLUSTER_ID'] = \"0112-211956-zzzzz\" # e.g. 0232-316956-tsltsyz5\n", + "os.environ['DATABRICKS_TOKEN'] = \"\"\n", + "os.environ['OPENAI_API_KEY'] = \"\"\n", + "\n", + "os.environ['H2OGPT_URL'] = 'http://38.128.233.247'\n", + "os.environ['H2OGPT_API_TOKEN'] = \"\"\n", + "os.environ['H2OGPTE_URL'] = \"https://h2ogpte.genai.h2o.ai\" # e.g. https://<>.h2ogpte.h2o.ai\n", + "os.environ['H2OGPTE_API_TOKEN'] = \"\"" + ] + }, + { + "cell_type": "markdown", + "id": "81b24fab-049f-48bd-bafa-aa3da2ed2a2a", + "metadata": {}, + "source": [ + "# Scenario 1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b3c4683-8d1f-42f3-8118-25ebdc0ca1e1", + "metadata": {}, + "outputs": [], + "source": [ + "# 1. New dataset for QnA\n", + "# 2. Generate SQL syntax, but no access to Databricks cluster for runtime validation.\n", + "\n", + "HOST_NAME = \"localhost\"\n", + "USER_NAME = \"sqlite\"\n", + "PASSWORD = \"abc\"\n", + "DB_NAME = \"querydb\"\n", + "PORT = \"5432\"\n", + "\n", + "\n", + "# Given .csv file, auto-generate schema\n", + "# Download dataset --> https://www.kaggle.com/datasets/uom190346a/sleep-health-and-lifestyle-dataset\n", + "# Adjust data_path as needed.\n", + "data_path = \"./examples/demo/sleep_health_and_lifestyle_dataset.csv\"\n", + "table_name = \"sleep_health_eda_demo\"\n", + "\n", + "table_schema, table_info_path = generate_schema(data_path=data_path, output_path=f\"{cache_path}/{table_name}_table_info.jsonl\")\n", + "\n", + "# Set add_sample=False if no need to add rows to the table\n", + "response , err = db_setup(\n", + " db_name=DB_NAME,\n", + " hostname=HOST_NAME,\n", + " user_name=USER_NAME,\n", + " password=PASSWORD,\n", + " port=PORT,\n", + " table_info_path=table_info_path,\n", + " table_samples_path=data_path,\n", + " table_name=table_name,\n", + " local_base_path=base_path,\n", + " add_sample=True\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "fd3de067-1e85-4869-b5fc-9504dca7d273", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'.//var/lib/tmp/sleep_health_eda_demo_table_info.jsonl'" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "table_info_path" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "3a8859c2-20e2-43d9-9dff-72519a786e31", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'Column Name': 'Person_ID', 'Column Type': 'NUMERIC'},\n", + " {'Column Name': 'Gender',\n", + " 'Column Type': 'TEXT',\n", + " 'Sample Values': ['Male', 'Female']},\n", + " {'Column Name': 'Age', 'Column Type': 'NUMERIC'},\n", + " {'Column Name': 'Occupation',\n", + " 'Column Type': 'TEXT',\n", + " 'Sample Values': ['Accountant',\n", + " 'Salesperson',\n", + " 'Engineer',\n", + " 'Software Engineer',\n", + " 'Nurse',\n", + " 'Teacher',\n", + " 'Manager',\n", + " 'Lawyer',\n", + " 'Doctor',\n", + " 'Sales Representative']},\n", + " {'Column Name': 'Sleep_Duration', 'Column Type': 'NUMERIC'},\n", + " {'Column Name': 'Quality_of_Sleep', 'Column Type': 'NUMERIC'},\n", + " {'Column Name': 'Physical_Activity_Level', 'Column Type': 'NUMERIC'},\n", + " {'Column Name': 'Stress_Level', 'Column Type': 'NUMERIC'},\n", + " {'Column Name': 'BMI_Category',\n", + " 'Column Type': 'TEXT',\n", + " 'Sample Values': ['Overweight', 'Normal', 'Obese', 'Normal Weight']},\n", + " {'Column Name': 'Blood_Pressure', 'Column Type': 'TEXT'},\n", + " {'Column Name': 'Heart_Rate', 'Column 
Type': 'NUMERIC'},\n", + " {'Column Name': 'Daily_Steps', 'Column Type': 'NUMERIC'},\n", + " {'Column Name': 'Sleep_Disorder',\n", + " 'Column Type': 'TEXT',\n", + " 'Sample Values': ['None', 'Sleep Apnea', 'Insomnia']}]" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "table_schema" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "606933f5-ecfb-4671-b9bf-cb531f4143e1", + "metadata": {}, + "outputs": [], + "source": [ + "# One could also pass a pre-configured table schema via the table_schema argument for table creation\n", + "# Schema format:\n", + "\n", + "# [{'Column Name': 'Person_ID', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'Gender',\n", + "# 'Column Type': 'TEXT',\n", + "# 'Sample Values': ['Male', 'Female']},\n", + "# {'Column Name': 'Age', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'Occupation',\n", + "# 'Column Type': 'TEXT',\n", + "# 'Sample Values': ['Accountant',\n", + "# 'Manager',\n", + "# 'Sales Representative',\n", + "# 'Salesperson',\n", + "# 'Lawyer',\n", + "# 'Nurse',\n", + "# 'Teacher',\n", + "# 'Software Engineer',\n", + "# 'Doctor',\n", + "# 'Scientist']},\n", + "# {'Column Name': 'Sleep_Duration', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'Quality_of_Sleep', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'Physical_Activity_Level', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'Stress_Level', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'BMI_Category',\n", + "# 'Column Type': 'TEXT',\n", + "# 'Sample Values': ['Overweight', 'Normal', 'Obese', 'Normal Weight']},\n", + "# {'Column Name': 'Blood_Pressure', 'Column Type': 'TEXT'},\n", + "# {'Column Name': 'Heart_Rate', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'Daily_Steps', 'Column Type': 'NUMERIC'},\n", + "# {'Column Name': 'Sleep_Disorder',\n", + "# 'Column Type': 'TEXT',\n", + "# 'Sample Values': ['None', 'Sleep Apnea', 'Insomnia']}]\n", + "\n", + "\n", + "\n", + "\n", + "# response, err = db_setup(\n", + "# db_name=DB_NAME,\n", + "# hostname=HOST_NAME,\n", + "# user_name=USER_NAME,\n", + "# password=PASSWORD,\n", + "# port=PORT,\n", + "# table_name=table_name,\n", + "# table_schema=table_schema,\n", + "# local_base_path=base_path,\n", + "# add_sample=False\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "045271a5-a9b3-4b4d-b439-a84d5e4d93fe", + "metadata": {}, + "outputs": [], + "source": [ + "def query(question: str, table_name, table_info_path: str, sample_qna_path: str, regenerate=False, regenerate_with_options=False):\n", + " \"\"\"Asks question and returns SQL.\"\"\"\n", + " base_path = \".\"\n", + " # self_correction is enabled by default, set to False if not needed.\n", + " res = ask(\n", + " question=question,\n", + " table_info_path=table_info_path,\n", + " sample_queries_path=sample_qna_path,\n", + " table_name=table_name,\n", + " db_dialect=\"databricks\",\n", + " is_command=False,\n", + " model_name=\"h2ogpt-sql-sqlcoder-34b-alpha\",\n", + " is_regenerate=regenerate,\n", + " is_regen_with_options=regenerate_with_options,\n", + " execute_query=False,\n", + " local_base_path=base_path,\n", + " )\n", + " sql_str = None\n", + " if res:\n", + " _tmp = res[0][1].split(\"```\")[1].replace(\"sql\", '')\n", + " sql_str = ' '.join(_tmp.split())\n", + " return sql_str" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "389a0464-c834-44dd-84fb-fb0c438e2a2d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ 
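The `query` helper above recovers the generated SQL by splitting the markdown response on the first triple-backtick fence and then deleting the literal substring "sql" everywhere, which would also mangle any identifier that happens to contain "sql". A slightly more defensive extraction could look like the sketch below; `extract_sql` is a hypothetical replacement helper, not part of the sidekick API.

```python
import re
from typing import Optional

def extract_sql(markdown_response: str) -> Optional[str]:
    """Return the first fenced code block from a markdown response, or None.

    Only an optional 'sql' language tag right after the opening fence is
    dropped, so identifiers containing the substring 'sql' survive intact.
    """
    match = re.search(r"```\s*(?:sql)?\s*(.*?)```", markdown_response,
                      flags=re.DOTALL | re.IGNORECASE)
    return " ".join(match.group(1).split()) if match else None

# Usage, in place of the split/replace above:
# sql_str = extract_sql(res[0][1])
```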
+ "\u001b[32m2024-01-19 09:35:33.851\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m493\u001b[0m - \u001b[1mTable in use: ['sleep_health_eda_demo']\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.852\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m494\u001b[0m - \u001b[1mSQL dialect for generation: databricks\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.854\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m527\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.855\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m528\u001b[0m - \u001b[1mQuestion: What is the average sleep duration for each gender?\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.856\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m546\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/sleep_health_eda_demo_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.857\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m353\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.857\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m357\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.859\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m360\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.860\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m361\u001b[0m - \u001b[1mFree GPU memory: 22GB\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:33.860\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mload_embedding_model\u001b[0m:\u001b[36m98\u001b[0m - \u001b[34m\u001b[1mLoading embedding model from: ./models/sentence_transformers\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "1c00fb4eba874ab5849de7cb228362b6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 13 files: 0%| | 0/13 [00:00> or ->\n", + "- Use prepared statements with parameterized queries to prevent SQL injection\n", + "\n", + "\n", + "### Input:\n", + "For SQL TABLE 'sleep_health_eda_demo' with sample question/answer pairs,\n", + "(), create a valid SQL (dialect:databricks) query to answer the following question:\n", + "What is the average sleep duration for each gender?.\n", + "This query will run on a database whose schema is represented in this string:\n", + "CREATE TABLE 'sleep_health_eda_demo' (['Person_ID NUMERIC, Gender TEXT, Age NUMERIC, Occupation TEXT, Sleep_Duration NUMERIC, Quality_of_Sleep NUMERIC, Physical_Activity_Level NUMERIC, Stress_Level NUMERIC, BMI_Category TEXT, Blood_Pressure TEXT, Heart_Rate NUMERIC, Daily_Steps NUMERIC, Sleep_Disorder TEXT,']\n", + ");\n", + "\n", + "-- Table 'sleep_health_eda_demo', , has sample values ({'sleep_health_eda_demo': [\"'Gender' contains values similar to Male,Female.\", \"'Occupation' contains values similar to Accountant,Salesperson,Engineer,Software 
Engineer,Nurse,Teacher,Manager,Lawyer,Doctor,Sales Representative.\", \"'BMI_Category' contains values similar to Overweight,Normal,Obese,Normal Weight.\", \"'Sleep_Disorder' contains values similar to None,Sleep Apnea,Insomnia.\"]})\n", + "\n", + "### Response:\n", + "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `What is the average sleep duration for each gender?`:\n", + "```SELECT\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:42.964\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m609\u001b[0m - \u001b[34m\u001b[1mGeneration with default temperature : 0.5\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://38.128.233.247:5000/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST http://38.128.233.247:5000/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST http://38.128.233.247:5000/v1/chat/completions \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 09:35:45.369\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m618\u001b[0m - \u001b[34m\u001b[1mGenerated tokens: \n", + "Gender, AVG(Sleep_Duration) AS Average_Sleep_Duration FROM sleep_health_eda_demo GROUP BY Gender ORDER BY Gender DESC;\n", + "\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:45.375\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m601\u001b[0m - \u001b[1mInput query: What is the average sleep duration for each gender?\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:45.377\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m602\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "SELECT `Gender`, AVG(`Sleep_Duration`) AS `Average_Sleep_Duration` FROM `sleep_health_eda_demo` GROUP BY `Gender` ORDER BY `Gender` DESC LIMIT 100\u001b[0m\n", + "\u001b[32m2024-01-19 09:35:45.384\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m628\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "[]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exiting...\n" + ] + } + ], + "source": [ + "res = query(\"What is the average sleep duration for each gender?\", table_name=table_name, \n", + " table_info_path=table_info_path, sample_qna_path=None, regenerate=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "2ea92cd8-e362-4359-a1a0-2d07b47aaf71", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'SELECT `Gender`, AVG(`Sleep_Duration`) AS `Average_Sleep_Duration` FROM `sleep_health_eda_demo` GROUP BY `Gender` ORDER BY `Gender` DESC LIMIT 100'" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "res" + ] + }, + { + "cell_type": "markdown", + "id": "58cf9a8d-d162-416e-8e13-714ecee2a179", + "metadata": {}, + "source": [ + "# Scenario 2" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "9e916d58-92cb-4adc-9c0c-afa3262c0071", + "metadata": {}, + "outputs": [], + "source": [ + "# 1. Connect to pre-configured Catalog/Schema on Databricks cluster\n", + "# 2. Retrieve information\n", + "# 3. 
Generate SQL syntax, execute against configured DB for the final result" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "b8409b8c-eab7-44fd-9c1f-b675502fa8d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:databricks.sql.client:Successfully opened session 89def2e0-902d-4230-91f3-0b2f8404e628\n", + "Successfully opened session 89def2e0-902d-4230-91f3-0b2f8404e628\n", + "Successfully opened session 89def2e0-902d-4230-91f3-0b2f8404e628\n" + ] + } + ], + "source": [ + "# Retrieve information\n", + "from sidekick.db_config import DBConfig\n", + "\n", + "DBConfig.dialect = \"databricks\"\n", + "# Using a demo dataset from Databricks Catalog\n", + "config_args = {\n", + " \"catalog\": \"samples\",\n", + " \"schema\": \"nyctaxi\",\n", + " \"cluster_id\": os.environ.get(\"DATABRICKS_CLUSTER_ID\")\n", + "}\n", + "table_name = \"trips\" # sample table related to NYC Taxi dataset\n", + "DBConfig.table_name = table_name\n", + "column_info, table_info_path = DBConfig.get_column_info(output_path=f\"{cache_path}/{table_name}_table_info.jsonl\", **config_args)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "beed4542-52ba-4a6c-b2bb-c613155ad86c", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 09:36:07.870\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m493\u001b[0m - \u001b[1mTable in use: ['trips']\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.871\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m494\u001b[0m - \u001b[1mSQL dialect for generation: databricks\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.873\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m527\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.874\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m528\u001b[0m - \u001b[1mQuestion: Compute average trip distance\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.875\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m546\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/trips_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.876\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m353\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.877\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m357\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.878\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m360\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.879\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m361\u001b[0m - \u001b[1mFree GPU memory: 22GB\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:07.880\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mload_embedding_model\u001b[0m:\u001b[36m98\u001b[0m - \u001b[34m\u001b[1mLoading embedding model from: 
.//models/sentence_transformers\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9fc31d9146bc4ccb849762069ae003c1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 13 files: 0%| | 0/13 [00:00> or ->\n", + "- Use prepared statements with parameterized queries to prevent SQL injection\n", + "\n", + "\n", + "### Input:\n", + "For SQL TABLE 'trips' with sample question/answer pairs,\n", + "(), create a valid SQL (dialect:databricks) query to answer the following question:\n", + "Compute average trip distance.\n", + "This query will run on a database whose schema is represented in this string:\n", + "CREATE TABLE 'trips' (['tpep_pickup_datetime DATETIME, tpep_dropoff_datetime DATETIME, trip_distance FLOAT, fare_amount FLOAT, pickup_zip INTEGER, dropoff_zip INTEGER,']\n", + ");\n", + "\n", + "-- Table 'trips', , has sample values ({})\n", + "\n", + "### Response:\n", + "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `Compute average trip distance`:\n", + "```SELECT\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:13.261\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m609\u001b[0m - \u001b[34m\u001b[1mGeneration with default temperature : 0.5\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:httpx:HTTP Request: POST http://38.128.233.247:5000/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST http://38.128.233.247:5000/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST http://38.128.233.247:5000/v1/chat/completions \"HTTP/1.1 200 OK\"\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 09:36:14.569\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m618\u001b[0m - \u001b[34m\u001b[1mGenerated tokens: AVG(trip_distance) AS average_trip_distance FROM trips;\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:14.572\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m601\u001b[0m - \u001b[1mInput query: Compute average trip distance\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:14.574\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m602\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "SELECT AVG(`trip_distance`) AS `average_trip_distance` FROM `trips` LIMIT 100\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:14.579\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m628\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "[]\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:14.581\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m641\u001b[0m - \u001b[1mChecking for vulnerabilities in the provided SQL: SELECT AVG(`trip_distance`) AS `average_trip_distance` FROM `trips` LIMIT 100\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:14.600\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mcheck_vulnerability\u001b[0m:\u001b[36m548\u001b[0m - \u001b[34m\u001b[1mRequesting additional scan using configured models\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: Server version 1.3.0-dev3 doesn't match client version 1.2.7: unexpected 
errors may occur.\n", + "Please install the correct version of H2OGPTE with `pip install h2ogpte==1.3.0-dev3`.\n", + "You can enable strict version checking by passing strict_version_check=True.\n", + "INFO:databricks.sql.client:Successfully opened session 2371a48a-b7f8-4497-b844-68fa41ec06f2\n", + "Successfully opened session 2371a48a-b7f8-4497-b844-68fa41ec06f2\n", + "Successfully opened session 2371a48a-b7f8-4497-b844-68fa41ec06f2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 09:36:30.379\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mexecute_query\u001b[0m:\u001b[36m268\u001b[0m - \u001b[34m\u001b[1mExecuting query:\n", + " SELECT AVG(\"trip_distance\") AS \"average_trip_distance\" FROM \"trips\" LIMIT 100\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:30.558\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mexecute_query\u001b[0m:\u001b[36m291\u001b[0m - \u001b[1mError occurred : (databricks.sql.exc.ServerOperationError) \n", + "[PARSE_SYNTAX_ERROR] Syntax error at or near '\"average_trip_distance\"'.(line 1, pos 31)\n", + "\n", + "== SQL ==\n", + "SELECT AVG(\"trip_distance\") AS \"average_trip_distance\" FROM \"trips\" LIMIT 100\n", + "-------------------------------^^^\n", + "\n", + "[SQL: SELECT AVG(\"trip_distance\") AS \"average_trip_distance\" FROM \"trips\" LIMIT 100]\n", + "(Background on this error at: https://sqlalche.me/e/20/4xp6)\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:databricks.sql.client:Closing session 2371a48a-b7f8-4497-b844-68fa41ec06f2\n", + "Closing session 2371a48a-b7f8-4497-b844-68fa41ec06f2\n", + "Closing session 2371a48a-b7f8-4497-b844-68fa41ec06f2\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 09:36:30.689\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m666\u001b[0m - \u001b[1mAttempting to auto-correct the query...\u001b[0m\n", + "\u001b[32m2024-01-19 09:36:30.690\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m669\u001b[0m - \u001b[34m\u001b[1mAttempt: 1\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Warning: Server version 1.3.0-dev3 doesn't match client version 1.2.7: unexpected errors may occur.\n", + "Please install the correct version of H2OGPTE with `pip install h2ogpte==1.3.0-dev3`.\n", + "You can enable strict version checking by passing strict_version_check=True.\n", + "INFO:databricks.sql.client:Successfully opened session 856311f5-20ca-45a9-b58d-b3ea0a427ecc\n", + "Successfully opened session 856311f5-20ca-45a9-b58d-b3ea0a427ecc\n", + "Successfully opened session 856311f5-20ca-45a9-b58d-b3ea0a427ecc\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 09:36:48.508\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mexecute_query\u001b[0m:\u001b[36m268\u001b[0m - \u001b[34m\u001b[1mExecuting query:\n", + " SELECT AVG(`trip_distance`) AS `average_trip_distance` FROM `trips` LIMIT 100\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:databricks.sql.client:Closing session 856311f5-20ca-45a9-b58d-b3ea0a427ecc\n", + "Closing session 856311f5-20ca-45a9-b58d-b3ea0a427ecc\n", + "Closing session 
856311f5-20ca-45a9-b58d-b3ea0a427ecc\n", + "The query results are:\n", + " [(2.8528291993434256,)]\n", + "Exiting...\n" + ] + } + ], + "source": [ + "input_q = \"Compute average trip distance\"\n", + "result, _, error = ask(\n", + " question=input_q,\n", + " table_info_path=table_info_path,\n", + " sample_queries_path=None,\n", + " table_name=table_name,\n", + " is_command=False,\n", + " model_name=\"h2ogpt-sql-sqlcoder-34b-alpha\",\n", + " db_dialect=\"databricks\",\n", + " execute_db_dialect=\"databricks\",\n", + " is_regenerate=False,\n", + " is_regen_with_options=False,\n", + " execute_query=True,\n", + " local_base_path=base_path,\n", + " debug_mode=False,\n", + " guardrails=True,\n", + " self_correction=True\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0105c31c-fb21-4f81-a734-1b4c45f74b24", + "metadata": {}, + "outputs": [], + "source": [ + "?ask" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "30caae94-37a3-4563-b311-42833e402c39", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['**Generated response for question,**\\nCompute average trip distance\\n',\n", + " '``` sql\\nSELECT AVG(`trip_distance`) AS `average_trip_distance`\\nFROM `trips`\\nLIMIT 100\\n```\\n\\n',\n", + " '\\n',\n", + " '**Result:** \\n',\n", + " '[(2.8528291993434256,)]',\n", + " '\\n']" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "id": "be504033-2da0-4a02-b710-04d0c5a5eb23", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:databricks.sql.client:Successfully opened session 84ea4e01-818d-4f1c-a4a2-7dbb94cd646a\n", + "Successfully opened session 84ea4e01-818d-4f1c-a4a2-7dbb94cd646a\n", + "Successfully opened session 84ea4e01-818d-4f1c-a4a2-7dbb94cd646a\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-19 09:37:58.015\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mexecute_query\u001b[0m:\u001b[36m268\u001b[0m - \u001b[34m\u001b[1mExecuting query:\n", + " SELECT SUM(`trip_distance`) AS `average_trip_distance`\n", + "FROM `trips`\n", + "LIMIT 1\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:databricks.sql.client:Closing session 84ea4e01-818d-4f1c-a4a2-7dbb94cd646a\n", + "Closing session 84ea4e01-818d-4f1c-a4a2-7dbb94cd646a\n", + "Closing session 84ea4e01-818d-4f1c-a4a2-7dbb94cd646a\n" + ] + }, + { + "data": { + "text/plain": [ + "([(62568.24999999994,)],\n", + " 'SELECT SUM(`trip_distance`) AS `average_trip_distance`\\nFROM `trips`\\nLIMIT 1')" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# One can make further changes and re-execute the query against the configured Database as required.\n", + "DBConfig.execute_query(\"SELECT SUM(`trip_distance`) AS `average_trip_distance`\\nFROM `trips`\\nLIMIT 1\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git 
a/examples/notebooks/sdk_quick_tutorial.ipynb b/examples/notebooks/sdk_quick_tutorial.ipynb new file mode 100644 index 0000000..de888d8 --- /dev/null +++ b/examples/notebooks/sdk_quick_tutorial.ipynb @@ -0,0 +1,1144 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 20, + "id": "60080b7e-2e80-4154-aa35-87c13b6ab371", + "metadata": {}, + "outputs": [], + "source": [ + "# https://github.com/h2oai/sql-sidekick/releases\n", + "#!python3 -m pip install --force-reinstall sql_sidekick-0.2.2-py3-none-any.whl" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "f480e37a-4327-48da-8c84-aba0ac1eef23", + "metadata": {}, + "outputs": [], + "source": [ + "import logging\n", + "import sys\n", + "\n", + "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", + "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c91887ce-c74a-432b-a3f9-120c8abc0003", + "metadata": {}, + "outputs": [], + "source": [ + "from sidekick.prompter import db_setup, ask\n", + "from sidekick.schema_generator import generate_schema\n", + "from sidekick.utils import setup_dir, list_models" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9fc212c8-dc73-4330-a07f-7394fd198395", + "metadata": {}, + "outputs": [], + "source": [ + "base_path = \"./\"\n", + "cache_path = f\"{base_path}/var/lib/tmp\"\n", + "setup_dir(base_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6421a995-f846-4a1e-8292-374bd7500382", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "f = pd.read_csv(\"./sleep_health_and_lifestyle_dataset.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "eac0fa65-bb06-415a-aa87-1185789f878d", + "metadata": {}, + "outputs": [], + "source": [ + "# Env variables\n", + "import os\n", + "\n", + "os.environ['OPENAI_API_KEY'] = \"\"\n", + "os.environ['H2OGPT_URL'] = 'http://38.128.233.247'\n", + "os.environ['H2OGPT_API_TOKEN'] = \"\"\n", + "# To get access to h2ogpte endpoint, reach out to cloud-feedback@h2o.ai\n", + "os.environ['H2OGPTE_URL'] = \"https://h2ogpte.genai.h2o.ai\" # e.g. 
https://<>.h2ogpte.h2o.ai\n", + "os.environ['H2OGPTE_API_TOKEN'] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "62e23b39-caa8-4e2f-bf12-678dd586f0df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Information supplied:\n", + " querydb, localhost, sqlite, abc, 5432\n", + "Database already exists!\n", + "Table name: sleep_health_eda\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-27 20:35:06.568\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36m_extract_schema_info\u001b[0m:\u001b[36m162\u001b[0m - \u001b[34m\u001b[1mUsing schema information from: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:06.572\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mcreate_table\u001b[0m:\u001b[36m186\u001b[0m - \u001b[34m\u001b[1mSchema info used for creating table:\n", + " Person_ID NUMERIC,\n", + "Gender TEXT COLLATE NOCASE,\n", + "Age NUMERIC,\n", + "Occupation TEXT COLLATE NOCASE,\n", + "Sleep_Duration NUMERIC,\n", + "Quality_of_Sleep NUMERIC,\n", + "Physical_Activity_Level NUMERIC,\n", + "Stress_Level NUMERIC,\n", + "BMI_Category TEXT COLLATE NOCASE,\n", + "Blood_Pressure TEXT COLLATE NOCASE,\n", + "Heart_Rate NUMERIC,\n", + "Daily_Steps NUMERIC,\n", + "Sleep_Disorder TEXT COLLATE NOCASE\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:06.578\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36mcreate_table\u001b[0m:\u001b[36m198\u001b[0m - \u001b[1mTable created: sleep_health_eda\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checked table sleep_health_eda exists in the DB.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-27 20:35:06.586\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m222\u001b[0m - \u001b[34m\u001b[1mAdding sample values to table: ./sleep_health_and_lifestyle_dataset.csv\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:06.597\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m228\u001b[0m - \u001b[34m\u001b[1mInserting chunk: 0\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:06.755\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m233\u001b[0m - \u001b[1mData inserted into table: sleep_health_eda\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:06.759\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.db_config\u001b[0m:\u001b[36madd_samples\u001b[0m:\u001b[36m238\u001b[0m - \u001b[1mNumber of rows inserted: 2618\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Created a Database querydb. 
Inserted sample values from ./sleep_health_and_lifestyle_dataset.csv into table sleep_health_eda, please ask questions!\n" + ] + } + ], + "source": [ + "HOST_NAME = \"localhost\"\n", + "USER_NAME = \"sqlite\"\n", + "PASSWORD = \"abc\"\n", + "DB_NAME = \"querydb\"\n", + "PORT = \"5432\"\n", + "\n", + "\n", + "# Given .csv file, auto-generate schema\n", + "# Download dataset --> https://www.kaggle.com/datasets/uom190346a/sleep-health-and-lifestyle-dataset\n", + "data_path = \"./sleep_health_and_lifestyle_dataset.csv\"\n", + "table_name = \"sleep_health_eda\"\n", + "\n", + "r, table_info_path = generate_schema(data_path=data_path, output_path=f\"{cache_path}/{table_name}_table_info.jsonl\")\n", + "\n", + "llm_response, err = db_setup(\n", + " db_name=DB_NAME,\n", + " hostname=HOST_NAME,\n", + " user_name=USER_NAME,\n", + " password=PASSWORD,\n", + " port=PORT,\n", + " table_info_path=table_info_path,\n", + " table_samples_path=data_path,\n", + " table_name=table_name,\n", + " local_base_path=base_path\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "80dec22c-362e-41a0-8f34-0690465542e6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['h2ogpt-sql-sqlcoder2-4bit',\n", + " 'h2ogpt-sql-sqlcoder-34b-alpha-4bit',\n", + " 'h2ogpt-sql-nsql-llama-2-7B-4bit',\n", + " 'h2ogpt-sql-sqlcoder2',\n", + " 'h2ogpt-sql-sqlcoder-34b-alpha',\n", + " 'h2ogpt-sql-nsql-llama-2-7B',\n", + " 'gpt-3.5-turbo',\n", + " 'gpt-4-8k',\n", + " 'gpt-4-1106-preview-128k']" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# List supported models\n", + "list_models()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2b3db015-1d9e-46b0-ad58-2f5aac0c6e4c", + "metadata": {}, + "outputs": [], + "source": [ + "# Use `list_models()` to check other supported models\n", + "\n", + "def query(\n", + " question: str,\n", + " table_name,\n", + " table_info_path: str,\n", + " sample_qna_path: str,\n", + " regenerate=False,\n", + " regenerate_with_options=False,\n", + "):\n", + " \"\"\"Asks question and returns SQL.\"\"\"\n", + " base_path = \"./\"\n", + " # self_correction is enabled by default, set to False if not needed.\n", + " res = ask(\n", + " question=question,\n", + " table_info_path=table_info_path,\n", + " sample_queries_path=sample_qna_path,\n", + " table_name=table_name,\n", + " is_command=False,\n", + " model_name=\"h2ogpt-sql-sqlcoder2-4bit\", #Other default model option: h2ogpt-sql-sqlcoder-34b-alpha\n", + " is_regenerate=regenerate,\n", + " is_regen_with_options=regenerate_with_options,\n", + " execute_query=False,\n", + " local_base_path=base_path,\n", + " )\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "523f1a88-eea8-414c-89b1-b7a2b3126535", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-27 20:35:33.226\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m500\u001b[0m - \u001b[1mTable in use: ['sleep_health_eda']\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:33.229\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m501\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:33.231\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m534\u001b[0m - \u001b[1mSetting 
context...\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:33.232\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m535\u001b[0m - \u001b[1mQuestion: What is the average sleep duration for each gender?\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:33.234\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m553\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:33.235\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m355\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:33.236\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m359\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.049\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m362\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.055\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m363\u001b[0m - \u001b[1mFree GPU memory: 20GB\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.057\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36m__new__\u001b[0m:\u001b[36m83\u001b[0m - \u001b[1mLoading local model: h2ogpt-sql-sqlcoder2-4bit\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.058\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mload_causal_lm_model\u001b[0m:\u001b[36m382\u001b[0m - \u001b[1mTotal GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.059\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m390\u001b[0m - \u001b[1mFree GPU memory: 20GB\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.060\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m393\u001b[0m - \u001b[1mLoading model: defog/sqlcoder2 on device id: 0\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.062\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m394\u001b[0m - \u001b[34m\u001b[1mModel cache: .//models/\u001b[0m\n", + "\u001b[32m2024-01-27 20:35:35.063\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_load_llm\u001b[0m:\u001b[36m432\u001b[0m - \u001b[34m\u001b[1mLoading in 4 bit mode: True with device {'': 0}\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dec7435d27704941a96dcdb9951ed10e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/4 [00:00> or ->\n", + "- Use prepared statements with parameterized queries to prevent SQL injection\n", + "\n", + "\n", + "### Input:\n", + "For SQL TABLE 'sleep_health_eda' with sample question/answer pairs,\n", + "(), create a valid SQL (dialect:sqlite) query to answer the following question:\n", + "What is the average sleep duration for each gender?.\n", + "This query will run on a database whose schema is represented in this string:\n", + "CREATE TABLE 'sleep_health_eda' (['Person_ID NUMERIC, Gender TEXT, Age NUMERIC, Occupation TEXT, Sleep_Duration NUMERIC, 
Quality_of_Sleep NUMERIC, Physical_Activity_Level NUMERIC, Stress_Level NUMERIC, BMI_Category TEXT, Blood_Pressure TEXT, Heart_Rate NUMERIC, Daily_Steps NUMERIC, Sleep_Disorder TEXT,']\n", + ");\n", + "\n", + "-- Table 'sleep_health_eda', , has sample values ({'sleep_health_eda': [\"'Gender' contains values similar to Male,Female.\", \"'Occupation' contains values similar to Lawyer,Teacher,Doctor,Software Engineer,Scientist,Sales Representative,Accountant,Salesperson,Manager,Nurse.\", \"'BMI_Category' contains values similar to Overweight,Normal,Obese,Normal Weight.\", \"'Sleep_Disorder' contains values similar to None,Sleep Apnea,Insomnia.\"]})\n", + "\n", + "### Response:\n", + "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `What is the average sleep duration for each gender?`:\n", + "```SELECT\u001b[0m\n", + "\u001b[32m2024-01-27 20:36:22.461\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mContext length: 743\u001b[0m\n", + "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n", + "\u001b[32m2024-01-27 20:36:30.891\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m608\u001b[0m - \u001b[1mInput query: What is the average sleep duration for each gender?\u001b[0m\n", + "\u001b[32m2024-01-27 20:36:30.895\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m609\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "SELECT \"gender\", AVG(\"sleep_duration\") AS \"average_sleep_duration\" FROM \"sleep_health_eda\" GROUP BY \"gender\" LIMIT 100\u001b[0m\n", + "\u001b[32m2024-01-27 20:36:30.905\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m635\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "[]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exiting...\n" + ] + } + ], + "source": [ + "res = query(\"What is the average sleep duration for each gender?\", table_name=\"sleep_health_eda\", \n", + " table_info_path=table_info_path, sample_qna_path=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "b17e2b4f-8736-4d44-addc-db8d2be4ce51", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question = **Generated response for question,**\n", + "What is the average sleep duration for each gender?\n", + "\n", + "----\n", + "Generated SQL = ``` sql\n", + "SELECT \"gender\",\n", + " AVG(\"sleep_duration\") AS \"average_sleep_duration\"\n", + "FROM \"sleep_health_eda\"\n", + "GROUP BY \"gender\"\n", + "LIMIT 100\n", + "```\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(f\"Question = {res[0][0]}\")\n", + "print(\"----\")\n", + "print(f\"Generated SQL = {res[0][1]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "03c5dfc0-c6f0-4573-b36d-56dc7bcbe8bc", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-27 20:39:50.016\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m500\u001b[0m - \u001b[1mTable in use: ['sleep_health_eda']\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.017\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m501\u001b[0m - \u001b[1mSQL dialect for 
generation: sqlite\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.018\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m534\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.019\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m535\u001b[0m - \u001b[1mQuestion: What are the most common occupations among individuals in the dataset?\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.020\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m553\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.021\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m355\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.022\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m359\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.023\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m362\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.024\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m363\u001b[0m - \u001b[1mFree GPU memory: 8GB\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.038\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m469\u001b[0m - \u001b[1mUsing information info from path .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.039\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m603\u001b[0m - \u001b[1mComputing user request ...\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.043\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36msemantic_search\u001b[0m:\u001b[36m155\u001b[0m - \u001b[34m\u001b[1mInput questions: # query: what are the most common occupations among individuals in the dataset?\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e4589b85d3514f2ea3c88a505f15698c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Batches: 0%| | 0/1 [00:00> or ->\n", + "- Use prepared statements with parameterized queries to prevent SQL injection\n", + "\n", + "\n", + "### Input:\n", + "For SQL TABLE 'sleep_health_eda' with sample question/answer pairs,\n", + "(), create a valid SQL (dialect:sqlite) query to answer the following question:\n", + "What are the most common occupations among individuals in the dataset?.\n", + "This query will run on a database whose schema is represented in this string:\n", + "CREATE TABLE 'sleep_health_eda' (['Person_ID NUMERIC, Gender TEXT, Age NUMERIC, Occupation TEXT, Sleep_Duration NUMERIC, Quality_of_Sleep NUMERIC, Physical_Activity_Level NUMERIC, Stress_Level NUMERIC, BMI_Category TEXT, Blood_Pressure TEXT, Heart_Rate NUMERIC, Daily_Steps NUMERIC, Sleep_Disorder TEXT,']\n", + ");\n", + "\n", + "-- Table 'sleep_health_eda', , has sample values ({'sleep_health_eda': [\"'Gender' contains values similar to Male,Female.\", 
\"'Occupation' contains values similar to Lawyer,Teacher,Doctor,Software Engineer,Scientist,Sales Representative,Accountant,Salesperson,Manager,Nurse.\", \"'BMI_Category' contains values similar to Overweight,Normal,Obese,Normal Weight.\", \"'Sleep_Disorder' contains values similar to None,Sleep Apnea,Insomnia.\"]})\n", + "\n", + "### Response:\n", + "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `What are the most common occupations among individuals in the dataset?`:\n", + "```SELECT\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.161\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mContext length: 749\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.162\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m651\u001b[0m - \u001b[1mRegeneration requested on previous query ...\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:50.163\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m652\u001b[0m - \u001b[34m\u001b[1mSelected temperature for fast regeneration : 0.8\u001b[0m\n", + "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n", + "\u001b[32m2024-01-27 20:39:52.499\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m675\u001b[0m - \u001b[34m\u001b[1mTemperature saved: 0.8\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:52.512\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m608\u001b[0m - \u001b[1mInput query: What are the most common occupations among individuals in the dataset?\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:52.513\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m609\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "SELECT \"occupation\", COUNT(1) AS \"COUNT\" FROM \"sleep_health_eda\" GROUP BY \"occupation\" ORDER BY \"COUNT\" DESC LIMIT 100\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:52.516\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m635\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "[]\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exiting...\n" + ] + } + ], + "source": [ + "# On using re-generation flag we toggle the temperature values between 0 and 1 alternating between low \n", + "# (focus/conservative generation and high values (random/creative generation)\n", + "res = query(\"What are the most common occupations among individuals in the dataset?\", table_name=\"sleep_health_eda\", \n", + " table_info_path=table_info_path, sample_qna_path=None, regenerate=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "cf2fc33d-ea21-4ab2-9019-329f5bc2051d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question = **Generated response for question,**\n", + "What are the most common occupations among individuals in the dataset?\n", + "\n", + "----\n", + "Generated SQL = ``` sql\n", + "SELECT \"occupation\",\n", + " COUNT(1) AS \"COUNT\"\n", + "FROM \"sleep_health_eda\"\n", + "GROUP BY \"occupation\"\n", + "ORDER BY \"COUNT\" DESC\n", + "LIMIT 100\n", + "```\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(f\"Question = 
{res[0][0]}\")\n", + "print(\"----\")\n", + "print(f\"Generated SQL = {res[0][1]}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "b47bef8d-c991-4581-a7fc-23a056911c3f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2024-01-27 20:39:56.595\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m500\u001b[0m - \u001b[1mTable in use: ['sleep_health_eda']\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.597\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m501\u001b[0m - \u001b[1mSQL dialect for generation: sqlite\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.598\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m534\u001b[0m - \u001b[1mSetting context...\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.599\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m535\u001b[0m - \u001b[1mQuestion: What is the average sleep duration for each gender?\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.601\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m553\u001b[0m - \u001b[34m\u001b[1mTable info path: .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.602\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m355\u001b[0m - \u001b[1mNumber of GPUs: 1\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.604\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m359\u001b[0m - \u001b[34m\u001b[1mInformation on device: 0\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.605\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m362\u001b[0m - \u001b[1mTotal Memory: 23GB\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.607\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36mis_resource_low\u001b[0m:\u001b[36m363\u001b[0m - \u001b[1mFree GPU memory: 8GB\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.629\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36m_check_file_info\u001b[0m:\u001b[36m469\u001b[0m - \u001b[1mUsing information info from path .//var/lib/tmp/sleep_health_eda_table_info.jsonl\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.631\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m603\u001b[0m - \u001b[1mComputing user request ...\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.640\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.utils\u001b[0m:\u001b[36msemantic_search\u001b[0m:\u001b[36m155\u001b[0m - \u001b[34m\u001b[1mInput questions: # query: what is the average sleep duration for each gender?\u001b[0m\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5ed8c4c529c54952a30bfb4d99b7ec95", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Batches: 0%| | 0/1 [00:00> or ->\n", + "- Use prepared statements with parameterized queries to prevent SQL injection\n", + "\n", + "\n", + "### Input:\n", + "For SQL TABLE 'sleep_health_eda' with sample question/answer pairs,\n", + "(), create a valid SQL (dialect:sqlite) query to answer the 
following question:\n", + "What is the average sleep duration for each gender?.\n", + "This query will run on a database whose schema is represented in this string:\n", + "CREATE TABLE 'sleep_health_eda' (['Person_ID NUMERIC, Gender TEXT, Age NUMERIC, Occupation TEXT, Sleep_Duration NUMERIC, Quality_of_Sleep NUMERIC, Physical_Activity_Level NUMERIC, Stress_Level NUMERIC, BMI_Category TEXT, Blood_Pressure TEXT, Heart_Rate NUMERIC, Daily_Steps NUMERIC, Sleep_Disorder TEXT,']\n", + ");\n", + "\n", + "-- Table 'sleep_health_eda', , has sample values ({'sleep_health_eda': [\"'Gender' contains values similar to Male,Female.\", \"'Occupation' contains values similar to Lawyer,Teacher,Doctor,Software Engineer,Scientist,Sales Representative,Accountant,Salesperson,Manager,Nurse.\", \"'BMI_Category' contains values similar to Overweight,Normal,Obese,Normal Weight.\", \"'Sleep_Disorder' contains values similar to None,Sleep Apnea,Insomnia.\"]})\n", + "\n", + "### Response:\n", + "Based on your instructions, here is the SELECT SQL query I have generated to answer the question `What is the average sleep duration for each gender?`:\n", + "```SELECT\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.785\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m590\u001b[0m - \u001b[1mContext length: 743\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.787\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m677\u001b[0m - \u001b[1mRegeneration with options requested on previous query ...\u001b[0m\n", + "\u001b[32m2024-01-27 20:39:56.788\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m692\u001b[0m - \u001b[34m\u001b[1mSelected temperature for diverse beam search: 0.4\u001b[0m\n", + "/home/pramit/.jupyterven/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.4` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", + " warnings.warn(\n", + "/home/pramit/.jupyterven/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:407: UserWarning: `do_sample` is set to `False`. However, `top_k` is set to `5` -- this flag is only used in sample-based generation modes. 
You should set `do_sample=True` or unset `top_k`.\n", + " warnings.warn(\n", + "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n", + "\u001b[32m2024-01-27 20:42:01.434\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m727\u001b[0m - \u001b[1mGenerated options:\n", + "\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.440\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", + "Option 1: (_probability_: 0.381034255027771)\n", + "``` sql\n", + "SELECT gender,\n", + " AVG(sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.444\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", + "Option 2: (_probability_: 0.2624567449092865)\n", + "``` sql\n", + "SELECT AVG(sleep_duration) AS average_sleep_duration,\n", + " gender\n", + "FROM sleep_health_eda\n", + "GROUP BY gender\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.446\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", + "Option 3: (_probability_: 0.22498156130313873)\n", + "``` sql\n", + "SELECT Gender,\n", + " AVG(Sleep_Duration) AS average_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY Gender\n", + "ORDER BY average_duration DESC NULLS LAST\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.451\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", + "Option 4: (_probability_: 0.13085876405239105)\n", + "``` sql\n", + "SELECT 'Gender',\n", + " AVG('Sleep_Duration') AS average_sleep_duration\n", + "FROM'sleep_health_eda'\n", + "GROUP BY 'Gender'\n", + "ORDER BY average_sleep_duration DESC NULLS LAST\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.474\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.query\u001b[0m:\u001b[36mgenerate_sql\u001b[0m:\u001b[36m748\u001b[0m - \u001b[1mAlternate options:\n", + "Option 5: (_probability_: 0.0006686743581667542)\n", + "``` sql\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY 
average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.488\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m608\u001b[0m - \u001b[1mInput query: What is the average sleep duration for each gender?\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.489\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m609\u001b[0m - \u001b[1mGenerated response:\n", + "\n", + "SELECT \"gender\", AVG(\"sleep_duration\") AS \"average_sleep_duration\" FROM \"sleep_health_eda\" GROUP BY \"gender\" ORDER BY \"average_sleep_duration\" DESC LIMIT 100\u001b[0m\n", + "\u001b[32m2024-01-27 20:42:01.492\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36msidekick.prompter\u001b[0m:\u001b[36mask\u001b[0m:\u001b[36m635\u001b[0m - \u001b[1mAlternate responses:\n", + "\n", + "['Option 1: (_probability_: 0.381034255027771)\\n``` sql\\nSELECT gender,\\n AVG(sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY gender\\nORDER BY average_sleep_duration DESC NULLS LAST\\nLIMIT 100;\\n```\\n\\n\\n', 'Option 2: (_probability_: 0.2624567449092865)\\n``` sql\\nSELECT AVG(sleep_duration) AS average_sleep_duration,\\n gender\\nFROM sleep_health_eda\\nGROUP BY gender\\nLIMIT 100;\\n```\\n\\n\\n', 'Option 3: (_probability_: 0.22498156130313873)\\n``` sql\\nSELECT Gender,\\n AVG(Sleep_Duration) AS average_duration\\nFROM sleep_health_eda\\nGROUP BY Gender\\nORDER BY average_duration DESC NULLS LAST\\nLIMIT 100;\\n```\\n\\n\\n', \"Option 4: (_probability_: 0.13085876405239105)\\n``` sql\\nSELECT 'Gender',\\n AVG('Sleep_Duration') AS average_sleep_duration\\nFROM'sleep_health_eda'\\nGROUP BY 'Gender'\\nORDER BY average_sleep_duration DESC NULLS LAST\\nLIMIT 100;\\n```\\n\\n\\n\", 'Option 5: (_probability_: 0.0006686743581667542)\\n``` sql\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM 
sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\\nFROM sleep_health_eda\\nGROUP BY sleep_health_eda.gender\\nORDER BY average_sleep_duration DESC NULLS LAST;\\n\\nSELECT sleep_health_eda.gender,\\n AVG(sleep_health_eda\\nLIMIT 100;\\n```\\n\\n\\n']\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exiting...\n" + ] + } + ], + "source": [ + "# Alternate options\n", + "res = query(\"What is the average sleep duration for each gender?\", table_name=\"sleep_health_eda\", \n", + " table_info_path=table_info_path, sample_qna_path=None, regenerate_with_options=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e415c0b9-466e-4417-ac1e-493914a83c36", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Question = **Generated response for question,**\n", + "What is the average sleep duration for each gender?\n", + "\n", + "----Options----\n", + "Option 1: (_probability_: 0.381034255027771)\n", + "``` sql\n", + "SELECT gender,\n", + " AVG(sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\n", + "Option 2: (_probability_: 0.2624567449092865)\n", + "``` sql\n", + "SELECT AVG(sleep_duration) AS average_sleep_duration,\n", + " gender\n", + "FROM sleep_health_eda\n", + "GROUP BY gender\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\n", + "Option 3: (_probability_: 0.22498156130313873)\n", + "``` sql\n", + "SELECT Gender,\n", + " AVG(Sleep_Duration) AS average_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY Gender\n", + "ORDER BY average_duration DESC NULLS LAST\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\n", + "Option 4: (_probability_: 0.13085876405239105)\n", + "``` sql\n", + "SELECT 'Gender',\n", + " AVG('Sleep_Duration') AS average_sleep_duration\n", + "FROM'sleep_health_eda'\n", + "GROUP BY 'Gender'\n", + "ORDER BY average_sleep_duration DESC NULLS LAST\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\n", + 
"Option 5: (_probability_: 0.0006686743581667542)\n", + "``` sql\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda.sleep_duration) AS average_sleep_duration\n", + "FROM sleep_health_eda\n", + "GROUP BY sleep_health_eda.gender\n", + "ORDER BY average_sleep_duration DESC NULLS LAST;\n", + "\n", + "SELECT sleep_health_eda.gender,\n", + " AVG(sleep_health_eda\n", + "LIMIT 100;\n", + "```\n", + "\n", + "\n", + "\n" + ] + } + ], + "source": [ + "print(f\"Question = {res[0][0]}\")\n", + "print(\"----Options----\")\n", + "for _r in res[1]:\n", + " print(_r)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6308230-3595-400e-8f85-0a67bcdca029", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/sleep_eda/table_info.jsonl b/examples/sleep_eda/table_info.jsonl new file mode 100644 index 0000000..04fc5b7 --- /dev/null +++ b/examples/sleep_eda/table_info.jsonl @@ -0,0 +1,13 @@ +{"Column Name": "Person_ID", "Column Type": "uuid PRIMARY KEY"} +{"Column Name": "Gender", "Column Type": "TEXT", "Sample Values": ["Female", "Male"]} +{"Column Name": "Age", "Column Type": "NUMERIC"} +{"Column Name": "Occupation", 
"Column Type": "TEXT", "Sample Values": ["Accountant", "Doctor", "Engineer", "Lawyer","Manager", "Nurse", "Sales Representative", "Salesperson", "Scientist", "Software Engineer", "Teacher"]} +{"Column Name": "Sleep_Duration", "Column Type": "NUMERIC"} +{"Column Name": "Quality_of_Sleep", "Column Type": "NUMERIC"} +{"Column Name": "Physical_Activity_Level", "Column Type": "NUMERIC"} +{"Column Name": "Stress_Level", "Column Type": "NUMERIC"} +{"Column Name": "BMI_Category", "Column Type": "TEXT", "Sample Values": ["Normal", "Normal Weight", "Obese", "Overweight"]} +{"Column Name": "Blood_Pressure", "Column Type": "TEXT", "Sample Values": ["115/75", "115/78", "117/76", "118/75", "118/76", "119/77"]} +{"Column Name": "Heart_Rate", "Column Type": "NUMERIC"} +{"Column Name": "Daily_Steps", "Column Type": "NUMERIC"} +{"Column Name": "Sleep_Disorder", "Column Type": "TEXT", "Sample Values": ["Insomnia", "Sleep Apnea"]} diff --git a/examples/telemetry/samples.csv b/examples/telemetry/samples.csv new file mode 100644 index 0000000..6f2f598 --- /dev/null +++ b/examples/telemetry/samples.csv @@ -0,0 +1,27 @@ +query,answer +Total number of CPUs used?,SELECT sum((payload->'engineEvent'-> 'pausing' -> 'engine'->> 'cpu')::integer) AS total_cpus_used FROM telemetry WHERE payload->'engineEvent'-> 'pausing' -> 'engine'->> 'cpu' IS NOT NULL; +Find the number of AI units for each user using stream for each resource type (overall),"SELECT user_id, user_name, resource_type, date_trunc('day', ts) as start_day, + sum(AI_units) as AI_units FROM ( + SELECT user_id, user_name, resource_type, ts, + extract(epoch from ts - lag(ts) over (partition by user_id, resource_type order by ts)) / 3600 AS AI_units + FROM telemetry + WHERE stream = 'running' + ) sub GROUP BY user_id, user_name, resource_type, start_day +ORDER BY start_day DESC NULLS LAST;" +Compute global usage over time,"SELECT + ts AS time_interval, + GREATEST((GREATEST((ram_gi / 64.0), (cpu / 8.0)) - gpu), 0) + (gpu * 4.0) as ai_units +FROM ( + SELECT + -- This is a gauge stream, meaning multiple sources are exporting duplicate entries during the same hour interval + ts, + -- RAM usage in Gi + COALESCE(((payload->'usageGauge'->'billingResources'->>'paddedMemoryReservationBytes')::bigint/1024.0/1024.0/1024.0), 0) AS ram_gi, + -- CPU usage in vCPU + COALESCE(((payload->'usageGauge'->'billingResources'->'paddedCpuReservationMillicpu')::int/1000.0), 0) AS cpu, + -- GPU usage in number of GPUs + COALESCE(((payload->'usageGauge'->'billingResources'->'gpuCount')::int), 0) AS gpu + FROM telemetry + WHERE stream = 'gauage_resources' +) AS internal +ORDER BY 1, 2 DESC;" \ No newline at end of file diff --git a/examples/telemetry/table_info.jsonl b/examples/telemetry/table_info.jsonl new file mode 100644 index 0000000..9d957ba --- /dev/null +++ b/examples/telemetry/table_info.jsonl @@ -0,0 +1,10 @@ +{"Column Name": "id", "Column Type": "uuid PRIMARY KEY"} +{"Column Name": "ts", "Column Type": "TIMESTAMP WITH TIME ZONE NOT NULL"} +{"Column Name": "kind", "Column Type": "TEXT NOT NULL, -- or int?", "Sample Values": ["EVENT"]} +{"Column Name": "user_id", "Column Type": "TEXT"} +{"Column Name": "user_name", "Column Type": "TEXT"} +{"Column Name": "resource_type", "Column Type": "TEXT NOT NULL, -- or int?", "Sample Values": ["FEATURE_STORE", "PROJECT", "MLOPS_EXPERIMENT", "APP", "APP_INSTANCE", "MLOPS_DEPLOYMENT", "MLOPS_DATASET", "MLOPS_USER", "RESOURCE_TYPE_UNSPECIFIED", "SCORING", "DAI_ENGINE", "MLOPS_MODEL"]} +{"Column Name": "resource_id", "Column Type": "TEXT"} 
+{"Column Name": "stream", "Column Type": "TEXT NOT NULL", "Sample Values": ["air/h2o/cloud/mlops/deployment/created", "ai/h2o/cloud/appstore/instance/gauge/running", "ai/h2o/cloud/mlops/project/unshared", "ai/h2o/cloud/mlops/gauge/project", "ai/h2o/cloud/appstore/user/event/login", "ai/h2o/cloud/mlops/gauge/registered-model-version", "ai/h2o/cloud/appstore/instance/event/started", "ai/h2o/cloud/mlops/deployment/deleted", "ai/h2o/cloud/mlops/gauge/dataset", "ai/h2o/cloud/fs/job/running", "ai/h2o/engine/event/paused", "ai/h2o/cloud/mlops/project/deleted", "ai/h2o/engine/event/deleting", "ai/h2o/engine/event/pausing", "ai/h2o/cloud/mlops/gauge/deployment", "ai/h2o/cloud/usage/global/gauge/resources", "ai/h2o/cloud/mlops/gauge/registered-model", "ai/h2o/cloud/appstore/instance/event/suspended", "ai/h2o/cloud/usage/namespace/gauge/resources", "ai/h2o/cloud/mlops/registered-model-version/created", "ai/h2o/cloud/mlops/project/created", "ai/h2o/cloud/mlops/project/shared", "ai/h2o/cloud/mlops/experiment/created", "ai/h2o/cloud/mlops/dataset/created", "ai/h2o/cloud/appstore/app/event/created", "ai/h2o/cloud/appstore/instance/event/terminated", "ai/h2o/cloud/mlops/gauge/user", "ai/h2o/engine/event/starting", "ai/h2o/cloud/mlops/event/scoring-result/created", "ai/h2o/engine/event/running", "ai/h2o/cloud/fs/job/submitted", "ai/h2o/cloud/mlops/registered-model/created", "ai/h2o/cloud/mlops/gauge/experiment", "ai/h2o/document/ai/proxy", "ai/h2o/cloud/mlops/experiment/unlinked", "ai/h2o/cloud/fs/job/finished", "ai/h2o/cloud/appstore/app/event/deleted", "ai/h2o/cloud/appstore/instance/event/resumed"]} +{"Column Name": "source", "Column Type": "TEXT NOT NULL"} +{"Column Name": "payload", "Column Type": "jsonb NOT NULL", "Sample Values":[{"engineEvent": {"pausing": {"engine": {"cpu": "1", "memory": "1", "gpu": "0"}}}}]} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3c210b9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,90 @@ +[tool.poetry] +name = "sql-sidekick" +version = "0.2.2" +license = "Apache-2.0 license" +description = "An AI assistant for SQL generation" +authors = [ + "Pramit Choudhary ", + "Michal Malohlava " +] +readme = "README.md" +classifiers = [ + "Development Status :: Alpha", + "Environment :: CLI", + "Intended Audience :: Developers, Analysts", + "License :: Apache-2.0 license", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8+", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence" +] +packages = [{include = "sidekick"}] + +[tool.poetry.dependencies] +python = ">=3.8.1,<=3.10" +pandas = "^1.3.3" +numpy = "^1.21.2" +click = "^8.0.1" +sqlalchemy = "^2.0.25" +psycopg2-binary = "^2.9.6" +colorama = "^0.4.6" +llama_index = "^0.9.0" +loguru = "^0.7.0" +toml = "^0.10.2" +sqlglot = "^12.2.0" +sqlparse = "^0.4.4" +transformers = "^4.29.0" +sentence-transformers = "^2.2.2" +torch = "2.0.1" +sqlalchemy-utils = "^0.41.1" +h2o-wave = "0.26.3" +pandasql = "0.7.3" +accelerate = "0.21.0" +bitsandbytes = "0.41.0" +InstructorEmbedding = "^1.0.1" +h2ogpte = "1.2.7" +openai = "^1.3.0" +langchain-community = "^0.0.13" +python-dotenv = "^1.0.0" +databricks-sql-connector = "^3.0.1" + +[tool.poetry.scripts] +sql-sidekick = "sidekick.prompter:cli" + +[tool.poetry.dev-dependencies] +pylint = { version = "^2.12.2", allow-prereleases = true } +flake8 = { version = "^4.0.1", allow-prereleases = true } +black = { version = "21.12b0", allow-prereleases = true } +pytest = { version = "^7.4.4", allow-prereleases 
= true }
+
+[tool.black]
+line-length = 120
+skip-string-normalization = true
+target-version = ['py38']
+include = '\.pyi?$'
+exclude = '''
+(
+  /(
+      \.eggs         # exclude a few common directories in the
+    | \.git          # root of the project
+    | \.hg
+    | \.mypy_cache
+    | \.tox
+    | \.venv
+    | _build
+    | buck-out
+    | build
+    | dist
+  )/
+  | notebooks
+  | local
+)
+'''
+
+[tool.isort]
+line_length = 120
+multi_line_output = 3
+
+[build-system]
+requires = ["poetry-core>=1.0.0"]
+build-backend = "poetry.core.masonry.api"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9599e5e
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,112 @@
+accelerate==0.21.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+aiohttp==3.9.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+aiosignal==1.3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+ansicon==1.89.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows"
+anyio==4.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+async-timeout==4.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+attrs==23.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+beautifulsoup4==4.12.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+bitsandbytes==0.41.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+blessed==1.20.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+bs4==0.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+certifi==2023.11.17 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+charset-normalizer==3.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+click==8.1.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+colorama==0.4.6 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+databricks-sql-connector==3.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+dataclasses-json==0.6.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+deprecated==1.2.14 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+distro==1.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+editor==1.6.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+et-xmlfile==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+exceptiongroup==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+filelock==3.13.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+frozenlist==1.4.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+fsspec==2023.12.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+greenlet==3.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+h11==0.14.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+h2o-wave==0.26.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+h2ogpte==1.2.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+httpcore==0.17.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+httpx==0.24.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+huggingface-hub==0.20.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0"
+idna==3.6 ; python_full_version >= "3.8.1"
and python_full_version <= "3.10.0" +inquirer==3.2.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +instructorembedding==1.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +jinja2==3.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +jinxed==1.2.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and platform_system == "Windows" +joblib==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +jsonpatch==1.33 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +jsonpointer==2.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +langchain-community==0.0.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +langchain-core==0.1.11 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +langsmith==0.0.81 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +llama-index==0.9.32 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +loguru==0.7.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +lxml==4.9.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +lz4==4.3.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +markupsafe==2.1.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +marshmallow==3.20.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +mpmath==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +multidict==6.0.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +mypy-extensions==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +nest-asyncio==1.5.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +networkx==3.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +nltk==3.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +numpy==1.24.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +oauthlib==3.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +openai==1.8.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +openpyxl==3.1.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +packaging==23.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pandas==1.5.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pandasql==0.7.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pillow==10.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +psutil==5.9.7 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +psycopg2-binary==2.9.9 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pyarrow==14.0.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pydantic==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pydantic[dotenv]==1.10.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +python-dateutil==2.8.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +python-dotenv==1.0.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pytz==2023.3.post1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +pyyaml==6.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +readchar==4.0.5 ; python_full_version >= "3.8.1" and python_full_version <= 
"3.10.0" +regex==2023.12.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +requests==2.31.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +runs==1.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +safetensors==0.4.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +scikit-learn==1.3.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +scipy==1.10.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sentence-transformers==2.2.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sentencepiece==0.1.99 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +setuptools==69.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +six==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sniffio==1.3.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +soupsieve==2.5 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlalchemy-utils==0.41.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlalchemy==2.0.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlalchemy[asyncio]==2.0.25 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlglot==12.4.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sqlparse==0.4.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +starlette==0.35.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +sympy==1.12 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tenacity==8.2.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +threadpoolctl==3.2.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +thrift==0.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tiktoken==0.5.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tokenizers==0.15.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +toml==0.10.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +torch==2.0.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +torchvision==0.15.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +tqdm==4.66.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +transformers==4.36.2 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +typing-extensions==4.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +typing-inspect==0.9.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +urllib3==2.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +uvicorn==0.26.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +wcwidth==0.2.13 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +websockets==11.0.3 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +win32-setctime==1.1.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" and sys_platform == "win32" +wrapt==1.16.0 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +xmod==1.8.1 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" +yarl==1.9.4 ; python_full_version >= "3.8.1" and python_full_version <= "3.10.0" diff --git a/sidekick/__init__.py b/sidekick/__init__.py new file mode 100644 index 0000000..e69de29 
diff --git a/sidekick/configs/__init__.py b/sidekick/configs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/sidekick/configs/data_template.py b/sidekick/configs/data_template.py
new file mode 100644
index 0000000..0f7f662
--- /dev/null
+++ b/sidekick/configs/data_template.py
@@ -0,0 +1,11 @@
+# Reference: https://github.com/openai/openai-cookbook/blob/main/examples/Backtranslation_of_SQL_queries.py
+question_query_samples = """
+{
+    "question": "{}",
+    "query": "{}"
+}
+"""
+
+schema_info_template = {"Column Name": "", "Column Type": "", "Sample Values": []}
+
+data_samples_template = "'{column_name}' contains values similar to {comma_separated_sample_values}."
diff --git a/sidekick/configs/env.toml b/sidekick/configs/env.toml
new file mode 100644
index 0000000..82ed886
--- /dev/null
+++ b/sidekick/configs/env.toml
@@ -0,0 +1,32 @@
+[MODEL_INFO]
+OPENAI_API_KEY = "" # Needed only for OpenAI models
+MODEL_NAME = "h2ogpt-sql-sqlcoder-34b-alpha" # Others: e.g. gpt-4, gpt-4-32k, text-davinci-003
+QUANT_TYPE = '4bit'
+
+H2OGPT_URL = 'http://38.128.233.247'
+H2OGPT_API_TOKEN = ""
+H2OGPTE_URL = ""
+H2OGPTE_API_TOKEN = ""
+
+RECOMMENDATION_MODEL = "h2oai/h2ogpt-4096-llama2-70b-chat"
+VULNERABILITY_SCANNER = "h2oai/h2ogpt-4096-llama2-70b-chat" # other options: openai models depending on availability (e.g. 'gpt-3.5-turbo')
+SELF_CORRECTION_MODEL = "h2oai/h2ogpt-4096-llama2-70b-chat" # other options: openai models depending on availability (e.g. 'gpt-3.5-turbo')
+
+[LOCAL_DB_CONFIG]
+HOST_NAME = "localhost"
+USER_NAME = "sqlite"
+PASSWORD = "abc"
+DB_NAME = "querydb"
+PORT = "5432"
+
+[LOGGING]
+LOG-LEVEL = "DEBUG"
+
+[DB-DIALECT]
+DB_TYPE = "sqlite"
+
+[TABLE_INFO]
+TABLE_INFO_PATH = "examples/demo/table_info.jsonl"
+SAMPLE_QNA_PATH = "examples/demo/demo_qa.csv"
+TABLE_SAMPLES_PATH = "examples/demo/demo_data.csv"
+TABLE_NAME = "demo"
diff --git a/sidekick/configs/prompt_template.py b/sidekick/configs/prompt_template.py
new file mode 100644
index 0000000..2a254cd
--- /dev/null
+++ b/sidekick/configs/prompt_template.py
@@ -0,0 +1,149 @@
+# Chain of thought for reasoning and task decomposition
+# Reference: https://arxiv.org/pdf/2201.11903.pdf
+TASK_PROMPT = {
+    "system_prompt": "Act as a Data Analyst",
+    "user_prompt": """
+    ### For table {_table_name}, given an input *Question*, let's work it out in a detailed step by step way and only return specific, detailed and informative tasks as an ordered numeric list for SQL generation to be sure we have the right answer.
+    Use values that are explicitly mentioned in the *Question*.
+    Use the *History* and *Context* sections for co-reference and to infer relationships and identify column names. *Context* contains entity mappings as key:value pairs.
+    If the words in the *Question* do not match column names in the *Data* section, search for them in the *Context* section.
+    Always use the *Context* entry with the highest similarity score to the *Question*.
+    If words in the *Question* match more than one key, include both values using "or" when forming step by step tasks.
+    If no information related to the *Question* is found, apply self reasoning and predict possible tasks.
+    Infer the return type of the Question.
+    DO NOT generate SQL response, only return itemized tasks.
+    # *Data:* \nFor table {_table_name} schema info is mentioned below,\n{_data_info}
+    # *History*: \n{_sample_queries}
+    # *Question*: For table {_table_name}, {_question_str}, *Context*: {_context}
+    # Output: Tasks: ordered numeric list of tasks
+    """,
+}
+
+# Few shot learning prompt
+## Task Description
+## Examples
+## Prompt
+# Reference: https://arxiv.org/pdf/2005.14165.pdf
+QUERY_PROMPT = """
+    ### System: Act as a SQL Expert
+    # For table {_table_name}, given an input *Question*, only generate syntactically correct {dialect} SQL queries.
+    # Let's work it out in a detailed step by step way using the reasoning from the *Tasks* section.
+    # Pick the SQL query which has the highest average log probability if more than one candidate is likely to answer the *Question*.
+    ### {dialect} SQL tables
+    ### *Data:* \nFor table {_table_name} schema info is mentioned below,\n{_data_info}
+    ### *History*:\n{_sample_queries}
+    ### *Question*: For table {_table_name}, {_question}
+    # SELECT 1
+    ### *Plan for table {_table_name}*:\n{_tasks}
+    ### *Policies for SQL generation*:
+    # Avoid overly complex SQL queries; favor concise, human readable SQL queries which are easy to understand and debug
+    # Avoid patterns that might be vulnerable to SQL injection
+    # Use values and column names that are explicitly mentioned in the question or in the *Data* section
+    # DO NOT query for columns that do not exist
+    # Validate column names with the table name when needed
+    # DO NOT USE aggregate and window functions together
+    # Use COUNT(1) instead of COUNT(*)
+    # Return with LIMIT 100
+    # Prefer NOT EXISTS to LEFT JOIN ON null id
+    # Avoid using the WITH statement
+    # When using DESC keep NULLs at the end
+    # Always cast the numerator as float when computing ratios
+    # If JSONB format found in Table schema, do pattern matching on keywords from the question and use SQL functions such as ->> or ->
+    # Use prepared statements with parameterized queries to prevent SQL injection
+    # Add explanation and reasoning for each SQL query
+    """
+
+DEBUGGING_PROMPT = {
+    "system_prompt": "Act as a SQL expert for {dialect} database",
+    "user_prompt": """
+Help fix the incorrect SQL Query provided below in the *Query* section,\n
+### Error: {ex_traceback}\n
+### Query:\n {qry_txt}\n\n
+### Output: Add '```sql' as prefix and '```' as suffix to generated SQL
+""",
+}
+
+GUARDRAIL_PROMPT = {
+    "system_prompt": "Act as a Security expert; your job is to detect SQL injection vulnerabilities",
+    "user_prompt": """
+Help audit SQL injection patterns within the provided SQL *Query*.
+Flag as vulnerable if any known SQL injection string pattern is found in the *Query*; a few *Examples* are provided below,
+### *Examples*:\n
+1. SELECT * FROM sleep_health_and_lifestyle_study WHERE UserId = 105; vulnerable: false
+2. SELECT * FROM sleep_health_and_lifestyle_study WHERE UserId = 105 OR 1=1; vulnerable: true
+\n
+Only SELECT queries are allowed; flag as vulnerable if other SQL statements are found in the *Query* (e.g. DROP, INSERT, UPDATE, DELETE, etc.).
+If more than one possible vulnerability is found, summarize them in a single explanation.\n
+### Query:\n {query_txt}\n\n
+### Output: Return result as a valid dictionary string using the JSON schema format; don't add a separate Explanation section or any text after the JSON schema, \n{schema}
+"""
+}
+
+NSQL_QUERY_PROMPT = """
+For {dialect} SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries})
+
+CREATE TABLE '{table_name}'({column_info}
+)
+
+Table '{table_name}' has sample values ({data_info_detailed})
+
+
+
+-- Using valid and syntactically correct {dialect} SQL syntax, answer the following question (check for typos, grammatical and spelling errors and fix them) with the information for '{table_name}' provided above; for final SQL only use column names from the CREATE TABLE (Do not query for columns that do not exist).
+
+
+-- Using reference for TABLES '{table_name}' {context}; {question_txt}?
+
+SELECT"""
+
+# https://colab.research.google.com/drive/13BIKsqHnPOBcQ-ba2p77L5saiepTIwu0#scrollTo=0eI-VpCkf-fN
+STARCODER2_PROMPT = """
+### Instructions:
+Your task is to convert a question into a valid {dialect} syntax SQL query, given a {dialect} database schema. Let's work this out step by step to be sure we have the right answer.
+Only use the column names from the CREATE TABLE statement.
+Adhere to these rules:
+- **Deliberately go through the question and database schema word by word** to appropriately answer the question
+- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
+- Only use supplied table names: **{table_name}** for generation
+- Only use column names from the CREATE TABLE statement: **{column_info}** for generation. DO NOT USE any other column names outside of this.
+- Avoid overly complex SQL queries; favor concise, human readable SQL queries which are easy to understand and debug
+- Avoid patterns that might be vulnerable to SQL injection, e.g. use proper sanitization and escaping for raw user input
+- Always cast the numerator as float when computing ratios
+- Always use COUNT(1) instead of COUNT(*)
+- If the question is asking for a rate, use COUNT to compute percentage
+- Avoid using the WITH statement
+- DO NOT USE aggregate and window functions together
+- Prefer NOT EXISTS to LEFT JOIN ON null id
+- When using DESC keep NULLs at the end
+- If JSONB format found in Table schema, do pattern matching on keywords from the question and use SQL functions such as ->> or ->
+- Use prepared statements with parameterized queries to prevent SQL injection
+
+
+### Input:
+For SQL TABLE '{table_name}' with sample question/answer pairs,\n({sample_queries}), create a valid SQL (dialect:{dialect}) query to answer the following question:\n{question_txt}.
+This query will run on a database whose schema is represented in this string:
+CREATE TABLE '{table_name}' ({column_info}
+);
+
+-- Table '{table_name}', {context}, has sample values ({data_info_detailed})
+
+### Response:
+Based on your instructions, here is the SELECT SQL query I have generated to answer the question `{question_txt}`:
+```SELECT"""
+
+
+RECOMMENDATION_PROMPT = """
+Generate {n_questions} simple questions for the given dataset.
+Only use the specified column names mentioned in *Data Schema*.
+
+### Data Schema:
+{data_schema}
+
+
+Output: ordered numeric list of questions
+
+
+### Response:
+1.
+""" diff --git a/sidekick/db_config.py b/sidekick/db_config.py new file mode 100644 index 0000000..2dac460 --- /dev/null +++ b/sidekick/db_config.py @@ -0,0 +1,295 @@ +# create db with supplied info +import json +from pathlib import Path +from typing import Any, Optional + +import pandas as pd +import sqlalchemy +from langchain_community.utilities import SQLDatabase +from psycopg2.extras import Json +from sidekick.configs.data_template import data_samples_template +from sidekick.logger import logger +from sidekick.schema_generator import generate_schema +from sqlalchemy import create_engine, text +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.schema import CreateTable +from sqlalchemy_utils import database_exists + + +class DBConfig: + dialect = "sqlite" + _url = None + db_name = "default" + + def __init__( + self, + hostname, + user_name, + password, + port, + base_path, + schema_info_path=None, + schema_info=None + ) -> None: + self.hostname = hostname + self.user_name = user_name + self.password = password + self.port = port + self._table_name = None + self.schema_info_path = schema_info_path + self.schema_info = schema_info + self._engine = None + self.base_path = base_path + self.column_names = [] + + if self.dialect == "sqlite": + self._url = f"sqlite:///{base_path}/db/sqlite/{self.db_name}.db" + elif self.dialect == "postgresql": + self._url = f"{self.dialect}://{self.user_name}:{self.password}@{self.hostname}:{self.port}/" + else: + self._url = None # currently databricks is initialized _get_raw_table_schema + DBConfig._url = self._url + + @property + def table_name(self): + return self._table_name + + @table_name.setter + def table_name(self, val): + self._table_name = val.lower().replace(" ", "_") + + @property + def engine(self): + return self._engine + + def db_exists(self): + if self.dialect == "sqlite": + engine = create_engine(f"{self._url}", echo=True) + else: + engine = create_engine(f"{self._url}{self.db_name}", echo=True) + return database_exists(f"{engine.url}") + + @classmethod + def _get_raw_table_schema(cls, **config_args:Any): + if cls.dialect == "databricks": + _catalog = config_args.get("catalog", "samples") + _schema = config_args.get("schema", "default") + _cluster_id = config_args.get("cluster_id", None) + db = SQLDatabase.from_databricks(catalog=_catalog, schema=_schema, cluster_id=_cluster_id) + tbl = [_t for _t in db._metadata.sorted_tables if _t.name == cls.table_name.lower()][0] + cls.engine = db._engine + cls._url = db._engine.url + # TODO pending sqlite/postgresql + create_table_info = CreateTable(tbl).compile(cls.engine) if tbl is not None else '' + return str(create_table_info).strip() + + @classmethod + def get_column_info(cls, output_path: str, engine_format:bool=True, **config_args:Any): + # Getting raw info should help in getting all relevant information about the columns including - foreign keys, primary keys, etc. + raw_info = cls._get_raw_table_schema(**config_args) + c_info = [_c.strip().split("\n)")[0] for _c in raw_info.split("(\n\t")[1].split(",")[:-1]] + c_info_dict = dict([(f"{_c.split(' ')[0]}", _c.split(' ')[1]) for _c in c_info]) + column_info, output_path = generate_schema(output_path=output_path, column_info=c_info_dict) if engine_format else (c_info_dict, None) + return column_info, output_path + + def create_db(self): + engine = create_engine(self._url) + self._engine = engine + try: + with engine.connect() as conn: + # conn.execute("commit") + # Do not substitute user-supplied database names here. 
+                if self.dialect != "sqlite":
+                    conn.execute("commit")
+                    res = conn.execute(f"CREATE DATABASE {self.db_name}")
+                    self._url = f"{self._url}{self.db_name}"
+                    return res, None
+                else:
+                    logger.debug("SQLite DB created successfully.")
+
+            return True, None
+        except SQLAlchemyError as sqla_error:
+            logger.debug(f"SQLAlchemy error: {sqla_error}")
+            return None, sqla_error
+        except Exception as error:
+            logger.debug(f"Error occurred: {error}")
+            return None, error
+
+    def _parser(self, file_handle=None, schema_info=None):
+        sample_values = []
+        res = []
+        _lines = file_handle if file_handle else schema_info
+        for line in _lines:
+            data = json.loads(line) if isinstance(line, str) and line.strip() else line
+            if "Column Name" in data and "Column Type" in data:
+                col_name = data["Column Name"]
+                self.column_names.append(col_name)
+                col_type = data["Column Type"]
+                if col_type.lower() == "text":
+                    col_type = col_type + " COLLATE NOCASE"
+                # If the column has sample values, save them in the cache for future use.
+                if "Sample Values" in data:
+                    _sample_values = data["Sample Values"]
+                    _ds = data_samples_template.format(
+                        column_name=col_name,
+                        comma_separated_sample_values=",".join(
+                            str(_sample_val) for _sample_val in _sample_values
+                        ),
+                    )
+                    sample_values.append(_ds)
+                _new_samples = f"{col_name} {col_type}"
+                res.append(_new_samples)
+        return res, sample_values
+
+    def _extract_schema_info(self, schema=None, schema_path=None):
+        # From jsonl format
+        # E.g. {"Column Name": "id", "Column Type": "uuid PRIMARY KEY"}
+        res = []
+        sample_values = []
+        try:
+            if schema is not None:
+                logger.debug("Using passed schema information.")
+                res, sample_values = self._parser(schema_info=schema)
+            else:
+                if schema_path is None:
+                    table_info_file = f"{self.base_path}/var/lib/tmp/data/table_context.json"
+                    if Path(table_info_file).exists():
+                        # Read the cached table context to locate the schema file.
+                        with open(table_info_file, "r") as infile:
+                            schema_path = json.load(infile)["schema_info_path"]
+                if schema_path and Path(schema_path).exists():
+                    logger.debug(f"Using schema information from: {schema_path}")
+                    with open(schema_path, "r") as in_file:
+                        res, sample_values = self._parser(file_handle=in_file)
+                if len(sample_values) > 0:
+                    # Cache the sample values for future use.
+                    with open(
+                        f"{self.base_path}/var/lib/tmp/data/{self._table_name}_column_values.json", "w"
+                    ) as outfile:
+                        json.dump(sample_values, outfile, indent=2, sort_keys=False)
+        except ValueError as ve:
+            logger.error(f"Error in reading table context file: {ve}")
+        return res
+
+    def create_table(self, schema_info_path=None, schema_info=None):
+        try:
+            engine = create_engine(self._url, isolation_level="AUTOCOMMIT")
+            self._engine = engine
+            if self.schema_info is None and schema_info_path:
+                # If schema information is not provided, extract it from the schema file.
+                self.schema_info = """,\n""".join(self._extract_schema_info(schema_path=schema_info_path)).strip()
+            else:
+                self.schema_info = """,\n""".join(self._extract_schema_info(schema=schema_info)).strip()
+
+            logger.debug(f"Schema info used for creating table:\n {self.schema_info}")
+            # Currently, multiple tables are not supported.
+ # TODO https://github.com/h2oai/sql-sidekick/issues/62 + create_syntax = f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + {self.schema_info} + ) + """ + with engine.connect() as conn: + if self.dialect != "sqlite": + conn.execute("commit") + conn.execute(text(create_syntax)) + logger.info(f"Table created: {self.table_name}") + return self.table_name, None + except SQLAlchemyError as sqla_error: + logger.debug("SQLAlchemy error:", sqla_error) + return None, sqla_error + except Exception as error: + logger.debug("Error Occurred:", error) + return None, error + + def has_table(self): + engine = create_engine(self._url) + return sqlalchemy.inspect(engine).has_table(self.table_name) + + def data_preview(self, table_name): + if table_name: + query_str = f"SELECT * FROM {table_name} LIMIT 10" + result = self.execute_query(query_str) + else: + result = "Table not found. Make sure uploaded dataset is registered." + return result + + def add_samples(self, data_csv_path=None): + conn_str = self._url + try: + logger.debug(f"Adding sample values to table: {data_csv_path}") + df_chunks = pd.read_csv(data_csv_path, chunksize=5000) + engine = create_engine(conn_str, isolation_level="AUTOCOMMIT") + + for idx, chunk in enumerate(df_chunks): + # Write rows to database + logger.debug(f"Inserting chunk: {idx}") + chunk.columns = self.column_names + # Make sure column names in the data-frame match the schema + chunk.to_sql(self.table_name, engine, if_exists="append", index=False, method="multi") + + logger.info(f"Data inserted into table: {self.table_name}") + # Fetch the number of rows from the table + sample_query = f"SELECT COUNT(*) AS ROWS FROM {self.table_name} LIMIT 1" + num_rows = pd.DataFrame(engine.connect().execute(text(sample_query))) + res = num_rows.values[0][0] if not num_rows.empty else 0 + logger.info(f"Number of rows inserted: {res}") + engine.dispose() + return res, None + except (SQLAlchemyError, AttributeError) as sqla_error: + logger.debug("SQLAlchemy error:", sqla_error) + return None, sqla_error + except Exception as error: + logger.debug("Error Occurred:", error) + return None, error + finally: + if engine: + engine.dispose() + + @classmethod + def execute_query(cls, query=None, n_rows=100): + output = [] + if cls.dialect == "sqlite" or cls.dialect == "databricks": + conn_str = cls._url + elif cls.dialect == "postgresql": + conn_str = f"{cls._url}{cls.db_name}" + else: + conn_str = None + + # Create an engine + engine = create_engine(conn_str) + # Create a connection + connection = engine.connect() + + try: + if query: + logger.debug(f"Executing query:\n {query}") + _query = text(query) + result = connection.execute(_query) + + # Process the query results + cnt = 0 + for row in result: + if cnt <= n_rows: + # Access row data using row[column_name] + output.append(row) + cnt += 1 + else: + break + # Close the connection + connection.close() + + # Close the engine + engine.dispose() + else: + logger.info("Query Empty or None!") + return output, query + except Exception as e: + err = f"Error occurred: {format(e)}" + logger.info(err) + return None, err + finally: + connection.close() + engine.dispose() diff --git a/sidekick/logger.py b/sidekick/logger.py new file mode 100644 index 0000000..b7a9999 --- /dev/null +++ b/sidekick/logger.py @@ -0,0 +1,9 @@ +from loguru import logger +import sys +import toml +from pathlib import Path + +logger.remove() +base_path = (Path(__file__).parent / "../").resolve() +env_settings = toml.load(f"{base_path}/sidekick/configs/env.toml") 
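+# Re-attach a single stderr sink at the log level configured in env.toml; the
+# default loguru handler was removed above, so the level is fully user-controlled.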
+logger.add(sys.stderr, level=env_settings["LOGGING"]["LOG-LEVEL"]) diff --git a/sidekick/memory.py b/sidekick/memory.py new file mode 100644 index 0000000..4ca2176 --- /dev/null +++ b/sidekick/memory.py @@ -0,0 +1,80 @@ +import json +import re +from pathlib import Path +from typing import Dict, List, Tuple + + +# Reference: https://python.langchain.com/en/latest/modules/memory/examples/custom_memory.html +class EntityMemory: + def __init__(self, k, path: str = None): + self.k = k + self.track_history: List = [] + self.track_entity: List = [] + self.path = path + + def extract_entity(self, question: str, answer: str) -> Tuple[List, List]: + # Currently, anything tagged between below tags are extracted respectively, + # 1. From Input text: some key + # 2. From Output text: some key + # TODO Chat mode for auto extraction of entities + c_k = re.findall(r"(.+?)", question) + c_val = re.findall(r"(.+?)", answer) + return (c_k, c_val) + + def save_context(self, info: str, extract_context: bool = True) -> Dict: + # Construct dictionary to record history + # { + # 'Query': + # 'Answer': + # } + # Extract info from the supplied text + split_token = ";" + query = " ".join(info.partition(":")[2].split(split_token)[0].strip().split()) + response = " ".join(info.partition(":")[2].split(split_token)[1].partition(":")[2].strip().split()) + + # Check if entity extraction is enabled + # Add logic for entity extraction + extracted_entity = None + if extract_context: + _k, _v = self.extract_entity(query, response) + k_v = " ".join(_k) + c_v = ", ".join(_v) + extracted_entity = {k_v: c_v} + self.track_entity.append(extracted_entity) + + chat_history = {} + if query.strip() and "SELECT".lower() in response.lower(): + # Remove and tags from the query/response before persisting + query = ( + query.lower().replace("", "").replace("", "").replace("", "").replace("", "") + ) + response = ( + response.lower() + .replace("", "") + .replace("", "") + .replace("", "") + .replace("", "") + ) + chat_history = {"Query": query, "Answer": response, "Entity": extracted_entity} + self.track_history.append(chat_history) + else: + raise ValueError("Response not valid. Please try again.") + # persist the information for future use + res = {"history": self.track_history, "entity": self.track_entity} + + # Persist added information locally + if chat_history: + # TODO: Persist history for each user. This flow is currently only affects openai models. 
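+            # Append one JSON object per line (JSONL) so earlier history survives
+            # across sessions.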
+ with open(f"{self.path}/var/lib/tmp/.cache/history.jsonl", "a") as outfile: + json.dump(chat_history, outfile) + outfile.write("\n") + if extract_context: + # Update context.json file for tracking entities + content_file_path = f"{self.path}/var/lib/tmp/data/context.json" + context_dict = extracted_entity + if Path(content_file_path).exists(): + context_dict = json.load(open(content_file_path, "r")) + context_dict.update(extracted_entity) + with open(content_file_path, "w") as outfile: + json.dump(context_dict, outfile, indent=4, sort_keys=False) + return res diff --git a/sidekick/prompter.py b/sidekick/prompter.py new file mode 100644 index 0000000..895bb06 --- /dev/null +++ b/sidekick/prompter.py @@ -0,0 +1,749 @@ +import gc +import json +import os +from pathlib import Path +from typing import Optional + +import click +import openai +import pandas as pd +import sqlparse +import toml +import torch +from colorama import Back as B +from colorama import Fore as F +from colorama import Style +from pandasql import sqldf +from sidekick.db_config import DBConfig +from sidekick.logger import logger +from sidekick.memory import EntityMemory +from sidekick.query import SQLGenerator +from sidekick.schema_generator import generate_schema +from sidekick.utils import (REMOTE_LLMS, _execute_sql, check_vulnerability, + execute_query_pd, extract_table_names, + generate_suggestions, save_query, setup_dir) + +__version__ = "0.2.2" + +# Load the config file and initialize required paths +app_base_path = (Path(__file__).parent / "../").resolve() +# Below check is to handle the case when the app is running on the h2o.ai cloud or locally +default_base_path = app_base_path if os.path.isdir("./.sidekickvenv/bin/") else "/meta_data" +env_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml") +db_dialect = env_settings["DB-DIALECT"]["DB_TYPE"] +model_name = env_settings["MODEL_INFO"]["MODEL_NAME"] +h2o_remote_url = env_settings["MODEL_INFO"]["H2OGPTE_URL"] +h2o_key = env_settings["MODEL_INFO"]["H2OGPTE_API_TOKEN"] +# h2ogpt base model urls +h2ogpt_base_model_url = env_settings["MODEL_INFO"]["H2OGPT_URL"] +h2ogpt_base_model_key = env_settings["MODEL_INFO"]["H2OGPT_API_TOKEN"] + +self_correction_model = env_settings["MODEL_INFO"]["SELF_CORRECTION_MODEL"] +recommendation_model = env_settings["MODEL_INFO"]['RECOMMENDATION_MODEL'] + +os.environ["TOKENIZERS_PARALLELISM"] = "False" +# Env variables +if not os.getenv("H2OGPT_URL"): + os.environ["H2OGPT_URL"] = h2ogpt_base_model_url +if not os.getenv("H2OGPT_API_TOKEN"): + os.environ["H2OGPT_API_TOKEN"] = h2ogpt_base_model_key +if not os.getenv("H2OGPTE_URL"): + os.environ["H2OGPTE_URL"] = h2o_remote_url +if not os.getenv("H2OGPTE_API_TOKEN"): + os.environ["H2OGPTE_API_TOKEN"] = h2o_key +if not os.getenv("SELF_CORRECTION_MODEL"): + os.environ["SELF_CORRECTION_MODEL"] = self_correction_model +if not os.getenv("RECOMMENDATION_MODEL"): + os.environ["RECOMMENDATION_MODEL"] = recommendation_model + +def color(fore="", back="", text=None): + return f"{fore}{back}{text}{Style.RESET_ALL}" + + +msg = """Welcome to the SQL Sidekick!\nI am an AI assistant that helps you with SQL queries. +I can help you with the following:\n +1. Configure a local database(for schema validation and syntax checking): `sql-sidekick configure db-setup`.\n +2. Learn contextual query/answer pairs: `sql-sidekick learn add-samples`.\n +3. Simply add context: `sql-sidekick learn update-context`.\n +4. Ask a question: `sql-sidekick query`. 
+""" + + +@click.group(help=msg) +@click.version_option("-V", "--version", message=f"sql-sidekick - {__version__}") +def cli(): + # Book-keeping + setup_dir(default_base_path) + + +@cli.group("configure") +def configure(): + """Helps in configuring local database.""" + + +def enter_table_name(): + val = input(color(F.GREEN, "", "Would you like to create a table for the database? (y/n): ")) + return val + + +def enter_file_path(table: str): + val = input(color(F.GREEN, "", f"Please input the CSV file path to table {table} : ")) + return val + + +@configure.command("log", help="Adjust log settings") +@click.option("--set_level", "-l", help="Set log level (Default: INFO)") +def set_loglevel(set_level): + env_settings["LOGGING"]["LOG-LEVEL"] = set_level + # Update settings file for future use. + f = open(f"{default_base_path}/sidekick/configs/env.toml", "w") + toml.dump(env_settings, f) + f.close() + + +def _get_table_info(cache_path: str, table_name: str = None): + # Search for the file in the default current path, if not present ask user to enter the path + if Path(f"{cache_path}/{table_name}_table_info.jsonl").exists(): + table_info_path = f"{cache_path}/{table_name}_table_info.jsonl" # input schema in jsonl format + else: + # Search for table related meta data in tables.json + # TODO: In future, metadata could be pushed on to a Db. + if Path(f"{cache_path}/tables.json").exists(): + f = open(f"{cache_path}/tables.json", "r") + table_metadata = json.load(f) + current_meta = table_metadata[table_name] + if "schema_info_path" in current_meta: + table_info_path = current_meta["schema_info_path"] + if table_info_path is None: + # if table_info_path is None, generate default schema n set path + data_path = current_meta["samples_path"] + _, table_info_path = generate_schema(data_path=data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl") + table_metadata = {"schema_info_path": table_info_path} + with open(f"{cache_path}/table_context.json", "w") as outfile: + json.dump(table_metadata, outfile, indent=4, sort_keys=False) + return table_info_path + + +def update_table_info(cache_path: str, table_info_path: str = None, table_name: str = None): + if Path(f"{cache_path}/table_context.json").exists(): + f = open(f"{cache_path}/table_context.json", "r") + table_metadata = json.load(f) + if table_name: + table_metadata["tables_in_use"] = [table_name] + if table_info_path: + table_metadata["schema_info_path"] = table_info_path + else: + table_metadata = dict() + if table_name: + table_metadata["tables_in_use"] = [table_name] + if table_info_path: + table_metadata["schema_info_path"] = table_info_path + + table_metadata["data_table_map"] = {} + with open(f"{cache_path}/table_context.json", "w") as outfile: + json.dump(table_metadata, outfile, indent=4, sort_keys=False) + +# Experimental, might change in future. 
+def recommend_suggestions(cache_path: str, table_name: str, n_qs: int=10): + column_names = [] + if cache_path is None: + path = f"{default_base_path}/var/lib/tmp/data" + logger.debug(f"Retrieve meta information for table {table_name}") + cache_path = _get_table_info(path, table_name) + logger.debug(f"Updated table info path: {cache_path}") + if Path(cache_path).exists(): + with open(cache_path, "r") as in_file: + for line in in_file: + if line.strip(): + data = json.loads(line) + if "Column Name" in data and "Column Type" in data: + col_name = data["Column Name"] + column_names.append(col_name) + try: + r_url = _key = None + # First check for keys in env variables + logger.debug(f"Checking environment settings ...") + env_url = os.environ["H2OGPTE_URL"] + env_key = os.environ["H2OGPTE_API_TOKEN"] + if env_url and env_key: + r_url = env_url + _key = env_key + elif Path(f"{app_base_path}/sidekick/configs/env.toml").exists(): + # Reload .env info + logger.debug(f"Checking configuration file ...") + env_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml") + r_url = env_settings["MODEL_INFO"]["H2OGPTE_URL"] + _key = env_settings["MODEL_INFO"]["H2OGPTE_API_TOKEN"] + else: + raise Exception("Model url or key is missing.") + + result = generate_suggestions(remote_url=r_url, client_key=_key, column_names=column_names, n_qs=n_qs + ) + except Exception as e: + logger.error(f"Something went wrong, check the supplied credentials:\n{e}") + result = None + return result + + + +@configure.command( + "generate_schema", help=f"Helps generate default schema for the selected Database dialect: {db_dialect}" +) +@click.option("--data_path", default="data.csv", help="Enter the path of csv", type=str) +@click.option("--output_path", default="table_info.jsonl", help="Enter the path of generated schema in jsonl", type=str) +def generate_input_schema(data_path, output_path): + _, o_path = generate_schema(data_path=data_path, output_path=output_path) + click.echo(f"Schema generated for the input data at {o_path}") + + +@configure.command("db-setup", help=f"Enter information to configure {db_dialect} database locally") +@click.option("--db_name", "-n", default="querydb", help="Database name", prompt="Enter Database name") +@click.option("--hostname", "-h", default="localhost", help="Database hostname", prompt="Enter hostname name") +@click.option("--user_name", "-u", default=f"{db_dialect}", help="Database username", prompt="Enter username name") +@click.option( + "--password", + "-p", + default="abc", + hide_input=True, + help="Database password", + prompt="Enter password", +) +@click.option("--port", "-P", default=5432, help="Database port", prompt="Enter port (default 5432)") +@click.option("--table-info-path", "-t", help="Table info path", default=None) +def db_setup_cli(db_name: str, hostname: str, user_name: str, password: str, port: int, table_info_path: str): + db_setup( + db_name=db_name, + hostname=hostname, + user_name=user_name, + password=password, + port=port, + table_info_path=table_info_path, + table_samples_path=None, + table_name=None, + is_command=True, + ) + + +def db_setup( + db_name: str, + hostname: str, + user_name: str, + password: str, + port: int, + table_name: str, + table_info_path: Optional[str] = None, + table_schema: Optional[list] = None, + table_samples_path: Optional[str] = None, + add_sample: bool=True, + is_command: bool = False, + local_base_path: Optional[str] = None +): + """Helps setup local database. + Args: + db_name (str): Database name. 
+        hostname (str): Hostname.
+        user_name (str): Username.
+        password (str): Password.
+        port (int): Port.
+        table_name (str): Table name.
+        table_info_path (str): Table info path.
+        table_schema (list): Table schema.
+        table_samples_path (str): Table samples path.
+        add_sample (bool): Add sample rows.
+        is_command (bool): Is command line interface.
+        local_base_path (str): Local base path.
+    """
+    click.echo(f" Information supplied:\n {db_name}, {hostname}, {user_name}, {password}, {port}")
+    try:
+        res = err = None
+        # TODO
+        # --- Need to remove the below keys from ENV toml --- #
+        # env_settings["TABLE_INFO"]["TABLE_INFO_PATH"] = table_info_path
+        # env_settings["TABLE_INFO"]["TABLE_SAMPLES_PATH"] = table_samples_path
+
+        # Update settings file for future use.
+        # Check if the env.toml exists.
+        env_config_fname = f"{app_base_path}/sidekick/configs/env.toml"
+        if Path(env_config_fname).exists():
+            env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] = hostname
+            env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] = user_name
+            env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] = password
+            env_settings["LOCAL_DB_CONFIG"]["PORT"] = port
+            env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] = db_name
+            with open(env_config_fname, "w") as f:
+                toml.dump(env_settings, f)
+        base_path = local_base_path if local_base_path else default_base_path
+        path = f"{base_path}/var/lib/tmp/data"
+        # For the current session
+        DBConfig.db_name = db_name
+        DBConfig.dialect = db_dialect
+        db_obj = DBConfig(hostname, user_name, password, port, base_path=base_path)
+
+        # Create the database
+        if db_obj.dialect == "sqlite" and not os.path.isfile(f"{base_path}/db/sqlite/{db_name}.db"):
+            res, err = db_obj.create_db()
+            click.echo("Database created successfully!")
+        elif not db_obj.db_exists():
+            res, err = db_obj.create_db()
+            click.echo("Database created successfully!")
+        else:
+            click.echo("Database already exists!")
+
+        # Create a table in the DB
+        val = enter_table_name() if is_command else "y"
+        while val.lower() not in ("y", "n"):
+            click.echo("Incorrect value. Enter yes (y) or no (n).")
+            val = enter_table_name()
+
+        if table_info_path is None and table_schema is None:
+            logger.debug(f"Retrieve meta information for table {table_name}")
+            table_info_path = _get_table_info(path, table_name)
+            logger.debug(f"Updated table info path: {table_info_path}")
+
+        if val.lower() == "y" or val.lower() == "yes":
+            table_value = input("Enter table name: ") if is_command else table_name
+            click.echo(f"Table name: {table_value}")
+            # Set table name
+            db_obj.table_name = table_value.lower().replace(" ", "_")
+            if table_schema:
+                res, err = db_obj.create_table(schema_info=table_schema)
+            elif table_info_path:
+                res, err = db_obj.create_table(schema_info_path=table_info_path)
+
+            update_table_info(path, table_info_path, db_obj.table_name)
+            # Check if the table exists; pending --> and doesn't have any rows
+            # Add rows to the table
+            if db_obj.has_table():
+                click.echo(f"Checked table {db_obj.table_name} exists in the DB.")
+                # Honor the user's answer when prompted; otherwise fall back to the add_sample flag
+                val = (
+                    input(color(F.GREEN, "", "Would you like to add a few sample rows (at least 3)? (y/n):"))
+                    if is_command and not add_sample
+                    else ("y" if add_sample else "n")
+                )
+                if val.lower().strip() in ("y", "yes"):
+                    val = input("Path to a CSV file to insert data from:") if is_command else table_samples_path
+                    res, err = db_obj.add_samples(val)
+                else:
+                    click.echo("Exiting...")
+                    return None, err
+            else:
+                echo_msg = "Job done. Ask a question now!"
+                click.echo(echo_msg)
+
+        if err is None:
+            click.echo(
+                f"Created a database {db_name}. Inserted sample values from {table_samples_path} into table {table_name}, please ask questions!"
+            )
+            return res, None
+        else:
+            return 0, err
+    except Exception as e:
+        error_msg = f"Error creating database. Check configuration parameters.\n: {e}"
+        click.echo(error_msg)
+        if not is_command:
+            return 0, error_msg
+
+
+@cli.group("learn")
+def learn():
+    """Helps in learning and building memory."""
+
+
+def _add_context(entity_memory: EntityMemory):
+    _FORMAT = '''# Add input Query and Response \n\n
+"Query": "<any query>";\n
+"Response": """<respective response>"""
+'''
+    res = click.edit(_FORMAT.replace("\t", ""))
+    # Check if the user entered any value
+    if res:
+        try:
+            _ = entity_memory.save_context(res)
+        except ValueError as ve:
+            logger.info(f"Not a valid input: {ve}. Try again.")
+
+
+@learn.command("add-samples", help="Helps add contextual query/answer pairs.")
+def add_query_response():
+    em = EntityMemory(k=5, path=default_base_path)
+    _add_context(em)
+    _more = "y"
+    while _more.lower() not in ("n", "no"):
+        _more = click.prompt("Would you like to add more samples? (y/n)")
+        if _more.lower() == "y":
+            _add_context(em)
+        else:
+            break
+
+
+@learn.command("update-context", help="Update context in memory for future use")
+def update_context():
+    """Helps learn context for generation."""
+    # Book-keeping
+    setup_dir(default_base_path)
+
+    context_dict = """{\n"<context_key>": "<context_value>"\n}
+    """
+    content_file_path = f"{default_base_path}/var/lib/tmp/data/context.json"
+    context_str = context_dict
+    if Path(content_file_path).exists():
+        context_dict = json.load(open(content_file_path, "r"))
+        context_dict["<context_key>"] = "<context_value>"
+        context_str = json.dumps(context_dict, indent=2)
+
+    updated_context = click.edit(context_str)
+    if updated_context:
+        context_dict = json.loads(updated_context)
+        # Drop the placeholder entry before persisting
+        if "<context_key>" in context_dict:
+            del context_dict["<context_key>"]
+        path = f"{default_base_path}/var/lib/tmp/data/"
+        with open(f"{path}/context.json", "w") as outfile:
+            json.dump(context_dict, outfile, indent=4, sort_keys=False)
+    else:
+        logger.debug("No content updated ...")
+
+
+@cli.command()
+@click.option("--question", "-q", help="Question to ask", prompt="Ask a question")
+@click.option("--table-info-path", "-t", help="Table info path", default=None)
+@click.option("--sample_qna_path", "-s", help="Samples path", default=None)
+def query(question: str, table_info_path: str, sample_qna_path: str):
+    """Asks a question and returns SQL."""
+    ask(
+        question=question,
+        table_info_path=table_info_path,
+        sample_queries_path=sample_qna_path,
+        table_name=None,
+        is_command=True,
+    )
+
+
+def data_preview(table_name):
+    hostname = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"]
+    user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"]
+    password = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"]
+    port = env_settings["LOCAL_DB_CONFIG"]["PORT"]
+
+    db_obj = DBConfig(hostname, user_name, password, port, base_path=default_base_path)
+    if not db_obj.table_name:
+        db_obj.table_name = table_name
+    q_res = db_obj.data_preview(table_name)
+    # Convert the result to a data-frame
+    res = pd.DataFrame(q_res[0]) if q_res and q_res[0] else pd.DataFrame()
+    return res
+
+
+def ask(
+    question: str,
+    table_info_path: str,
+    sample_queries_path: str,
+    table_name: str,
+    model_name: str = "h2ogpt-sql-nsql-llama-2-7B",
+    db_dialect: str = "sqlite",
+    execute_db_dialect: str = "sqlite",
+    is_regenerate: bool = False,
+    is_regen_with_options: bool = False,
+    is_command: bool = False,
+    debug_mode: bool = False,
+    execute_query: bool = True,
+    guardrails: bool = True,
+    self_correction: bool = True,
+    local_base_path=None,
+):
+    """Asks a question and returns the generated SQL.
+    Args:
+        question (str): Question to ask.
+        table_info_path (str): Path to table info.
+        sample_queries_path (str): Path to sample queries.
+        table_name (str): Table name.
+        model_name (str): Model name.
+        db_dialect (str): Database dialect to generate SQL for.
+        execute_db_dialect (str): Database dialect to execute against.
+        is_regenerate (bool): Regenerate SQL.
+        is_regen_with_options (bool): Regenerate SQL with options.
+        is_command (bool): Is command line interface.
+        execute_query (bool): Execute SQL.
+        debug_mode (bool): Debug mode.
+        guardrails (bool): Check generated SQL for known vulnerabilities before execution.
+        self_correction (bool): Self correction.
+        local_base_path (str): Local base path.
+
+    Returns:
+        list: List of results.
+    """
+
+    results = []
+    res = err = alt_res = None  # TODO - Need to handle errors if occurred
+    # Book-keeping
+    base_path = local_base_path if local_base_path else default_base_path
+    setup_dir(base_path)
+
+    # Check if the table exists
+    path = f"{base_path}/var/lib/tmp/data"
+    table_context_file = f"{path}/table_context.json"
+    table_context = json.load(open(table_context_file, "r")) if Path(table_context_file).exists() else {}
+    table_names = []
+
+    if not model_name:
+        model_name = env_settings["MODEL_INFO"]["MODEL_NAME"]
+    if table_name is not None:
+        table_names = [table_name.lower().replace(" ", "_")]
+    elif table_context and "tables_in_use" in table_context:
+        _tables = table_context["tables_in_use"]
+        table_names = [_t.lower().replace(" ", "_") for _t in _tables]
+    else:
+        # Ask for the table name only when more than one table exists.
+        table_names = [click.prompt("Which table to use?")]
+        table_context["tables_in_use"] = [_t.lower().replace(" ", "_") for _t in table_names]
+        with open(f"{path}/table_context.json", "w") as outfile:
+            json.dump(table_context, outfile, indent=4, sort_keys=False)
+    logger.info(f"Table in use: {table_names}")
+    logger.info(f"SQL dialect for generation: {db_dialect}")
+    # Check if the env.toml file exists
+    api_key = os.getenv("OPENAI_API_KEY", None)
+    if (model_name == 'gpt-3.5-turbo-0301' or model_name == 'gpt-3.5-turbo-1106') and api_key is None:
+        api_key = env_settings["MODEL_INFO"]["OPENAI_API_KEY"]
+        if api_key is None:
+            if is_command:
+                val = input(
+                    color(
+                        F.GREEN, "", "Looks like the API key is not set, would you like to set OPENAI_API_KEY? (y/n):"
+                    )
+                )
+                if val.lower() == "y":
+                    api_key = input(color(F.GREEN, "", "Enter OPENAI_API_KEY :"))
+
+        if api_key is None and is_command:
+            return ["Looks like the API key is not set, please set OPENAI_API_KEY!"], err
+
+        # Guard against persisting a missing key
+        if api_key and os.getenv("OPENAI_API_KEY", None) is None:
+            os.environ["OPENAI_API_KEY"] = api_key
+        env_settings["MODEL_INFO"]["OPENAI_API_KEY"] = api_key
+
+        # Update settings file for future use.
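+        # Illustrative env.toml fragment written below (hypothetical key value):
+        #   [MODEL_INFO]
+        #   OPENAI_API_KEY = "sk-..."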
+ f = open(f"{app_base_path}/sidekick/configs/env.toml", "w") + toml.dump(env_settings, f) + f.close() + if model_name: + if 'gpt-3.5' in model_name or 'gpt-4' in model_name: + openai.api_key = api_key + logger.info(f"OpenAI key found.") + sql_g = None + try: + # Set context + logger.info("Setting context...") + logger.info(f"Question: {question}") + # Get updated info from env.toml + host_name = env_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] + user_name = env_settings["LOCAL_DB_CONFIG"]["USER_NAME"] + passwd = env_settings["LOCAL_DB_CONFIG"]["PASSWORD"] + db_name = env_settings["LOCAL_DB_CONFIG"]["DB_NAME"] + + if execute_db_dialect.lower() == "sqlite": + db_url = f"sqlite:///{base_path}/db/sqlite/{db_name}.db" + elif execute_db_dialect.lower() == "postgresql": + db_url = f"{execute_db_dialect}+psycopg2://{user_name}:{passwd}@{host_name}/{db_name}".format( + user_name, passwd, host_name, db_name + ) + elif execute_db_dialect.lower() == "databricks": + db_url = DBConfig._url + + if table_info_path is None: + table_info_path = _get_table_info(path, table_name) + logger.debug(f"Table info path: {table_info_path}") + + # Check if the model is present remotely + if model_name: + _remote_model = any(model_name.lower() in _m.lower() for _m in REMOTE_LLMS) + else: + _remote_model = False + sql_g = SQLGenerator( + db_url=db_url, + openai_key=api_key, + model_name=model_name, + job_path=base_path, + data_input_path=table_info_path, + sample_queries_path=sample_queries_path, + is_regenerate_with_options=is_regen_with_options, + is_regenerate=is_regenerate, + db_dialect=db_dialect, + debug_mode=debug_mode, + remote_model=_remote_model + ) + if model_name and "h2ogpt-sql" not in model_name and not _execute_sql(question): + sql_g._tasks = sql_g.generate_tasks(table_names, question) + results.extend(["I am thinking step by step: \n", sql_g._tasks, "\n"]) + click.echo(sql_g._tasks) + + updated_tasks = None + if sql_g._tasks is not None and is_command: + edit_val = click.prompt("Would you like to edit the tasks? (y/n)") + if edit_val.lower() == "y": + updated_tasks = click.edit(sql_g._tasks) + click.echo(f"Tasks:\n {updated_tasks}") + else: + click.echo("Skipping edit...") + if updated_tasks is not None: + sql_g._tasks = updated_tasks + + # The interface could also be used to simply execute user provided SQL + # Keyword: "Execute SQL: " + if ( + question is not None + and "select" in question.lower() + and (question.lower().startswith("question:") or question.lower().startswith("q:")) + ): + _q = question.lower().split("q:")[1].split("r:")[0].strip() + res = question.lower().split("r:")[1].strip() + question = _q + elif _execute_sql(question) and debug_mode: + logger.info("Executing user provided SQL without generation...") + res = question.strip().lower().split("execute sql:")[1].strip() + else: + logger.info("Computing user request ...") + _check_cond = question.strip().lower().split("execute sql:") + if len(_check_cond) > 1: + question = _check_cond[1].strip() + res, alt_res = sql_g.generate_sql(table_names, question, model_name=model_name) + logger.info(f"Input query: {question}") + logger.info(f"Generated response:\n\n{res}") + + if res is not None: + updated_sql = None + res_val = "e" + if is_command: + while res_val.lower() in ["e", "edit", "r", "regenerate"]: + res_val = click.prompt( + "Would you like to 'edit' or 'regenerate' the SQL? Use 'e' to edit or 'r' to regenerate. 
" + "To skip, enter 's' or 'skip'" + ) + if res_val.lower() == "e" or res_val.lower() == "edit": + updated_sql = click.edit(res) + click.echo(f"Updated SQL:\n {updated_sql}") + elif res_val.lower() == "r" or res_val.lower() == "regenerate": + click.echo("Attempting to regenerate...") + res, alt_res = sql_g.generate_sql( + table_names, question, model_name=model_name, _dialect=db_dialect + ) + res = res.replace("β€œ", '"').replace("”", '"') + [res := res.replace(s, '"') for s in "β€˜`’'" if s in res] + logger.info(f"Input query: {question}") + logger.info(f"Generated response:\n\n{res}") + pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper") + syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n""" + results.extend([f"**Generated response for question,**\n{question}\n", syntax_highlight, "\n"]) + logger.info(f"Alternate responses:\n\n{alt_res}") + + exe_sql = "y" + if not execute_query: + if is_command: + exe_sql = click.prompt("Would you like to execute the generated SQL (y/n)?") + else: + exe_sql = "n" + + _val = updated_sql if updated_sql else res + if exe_sql.lower() == "y" or exe_sql.lower() == "yes": + # Before executing, check if known vulnerabilities exist in the generated SQL code. + if guardrails: + logger.info(f"Checking for vulnerabilities in the provided SQL: {_val}") + r, m = check_vulnerability(_val) if guardrails else (None, None) + q_res = m if r else None + + # For the time being, the default option is DB, but the user can be asked to select Database or pandas DF later. + option = "DB" # or DB + if option == "DB" and not r: + _val = _val.replace("β€œ", '"').replace("”", '"') + [_val := _val.replace(s, '"') for s in "β€˜`’'" if s in _val] + + q_res, err = DBConfig.execute_query(query=_val) + # Check for runtime/operational errors n attempt auto-correction + attempt = 0 + error_condition = lambda e: ('OperationalError'.lower() in e.lower() or 'OperationError'.lower() in e.lower() or 'Syntax error'.lower() in e.lower()) if e else False + if self_correction and error_condition(err): + logger.info("Attempting to auto-correct the query during runtime...") + while attempt !=3 and error_condition(err): + try: + logger.debug(f"Attempt: {attempt+1}") + _tmp = err.split("\n") + _err = _tmp[0].split("Error occurred:")[1] if len(_tmp) > 0 else None + env_url = os.environ["H2OGPTE_URL"] + env_key = os.environ["H2OGPTE_API_TOKEN"] + corr_sql = sql_g.self_correction(input_query=_val, error_msg=_err, remote_url=env_url, client_key=env_key) + q_res, err = DBConfig.execute_query(query=corr_sql) + if not 'Error occurred'.lower() in str(err).lower(): + err = None + attempt += 1 + except Exception as e: + logger.error(f"Something went wrong:\n{e}") + attempt += 1 + if m: + _t = "\n\n**Warning:**\n".join([str(q_res), m]) + q_res = _t + elif option == "pandas": + tables = extract_table_names(_val) + tables_path = dict() + if Path(f"{path}/table_context.json").exists(): + f = open(f"{path}/table_context.json", "r") + table_metadata = json.load(f) + for table in tables: + # Check if the local table_path exists in the cache + if table not in table_metadata["data_table_map"].keys(): + val = enter_file_path(table) + if not os.path.isfile(val): + click.echo("In-correct Path. Please enter again! 
+                                else:
+                                    tables_path[table] = val
+                                    table_metadata["data_table_map"][table] = val
+                                    break
+                            else:
+                                tables_path[table] = table_metadata["data_table_map"][table]
+                        assert len(tables) == len(tables_path)
+                        with open(f"{path}/table_context.json", "w") as outfile:
+                            json.dump(table_metadata, outfile, indent=4, sort_keys=False)
+                    # `sqldf` is a function; the exception class is exposed at the pandasql package level
+                    from pandasql import PandaSQLException
+                    try:
+                        q_res = execute_query_pd(query=_val, tables_path=tables_path, n_rows=100)
+                    except PandaSQLException as e:
+                        logger.error(f"Error in executing the query: {e}")
+                        click.echo("Error in executing the query. Validate the generated SQL and try again.")
+                        click.echo("No result to display.")
+
+                results.append("**Result:**\n")
+                if q_res:
+                    # Check the shape of the final result to avoid blowing up memory;
+                    # log a quick preview of the result
+                    if isinstance(q_res, list) and len(q_res) > 10:
+                        click.echo(f"Preview of the result:\n {pd.DataFrame(q_res).head(2)}")
+                    else:
+                        click.echo(f"The results are:\n {q_res}")
+                    results.extend([str(q_res), "\n"])
+                else:
+                    click.echo(f"While executing query:\n {err}")
+                    results.extend([str(err), "\n"])
+
+                save_sql = click.prompt("Would you like to save the generated SQL (y/n)?") if is_command else "n"
+                if save_sql.lower() in ("y", "yes"):
+                    # Persist for future use
+                    _val = updated_sql if updated_sql else res
+                    save_query(base_path, query=question, response=_val)
+            else:
+                click.echo("Exiting...")
+        else:
+            results = ["I was not able to generate a response for the question. Please try re-phrasing or try again."]
+            alt_res, err = None, None
+    except (MemoryError, RuntimeError, AttributeError) as e:
+        logger.error(f"Something went wrong while generating response: {e}")
+        if sql_g:
+            del sql_g
+        gc.collect()
+        torch.cuda.empty_cache()
+        alt_res, err = None, e
+        results = ["Something went wrong while generating response. Please check the supplied API keys and try again."]
+    return results, alt_res, err
+
+
+if __name__ == "__main__":
+    cli()
diff --git a/sidekick/query.py b/sidekick/query.py
new file mode 100644
index 0000000..6a23600
--- /dev/null
+++ b/sidekick/query.py
@@ -0,0 +1,815 @@
+import gc
+import json
+import os
+import random
+import sys
+from pathlib import Path
+
+import numpy as np
+import openai
+import sqlglot
+import sqlparse
+import torch
+import torch.nn.functional as F
+from llama_index import GPTVectorStoreIndex, ServiceContext, SQLDatabase
+from llama_index.indices.struct_store import SQLContextContainerBuilder
+from llama_index.indices.struct_store.sql import GPTSQLStructStoreIndex
+from llama_index.llms import OpenAI as LOpenAI
+from openai import OpenAI
+from sidekick.configs.prompt_template import (DEBUGGING_PROMPT,
+                                              NSQL_QUERY_PROMPT, QUERY_PROMPT,
+                                              STARCODER2_PROMPT, TASK_PROMPT)
+from sidekick.logger import logger
+from sidekick.utils import (MODEL_CHOICE_MAP_EVAL_MODE, _check_file_info,
+                            is_resource_low, load_causal_lm_model,
+                            load_embedding_model, make_dir, re_rank,
+                            read_sample_pairs, remove_duplicates,
+                            semantic_search)
+from sqlalchemy import create_engine
+
+
+class SQLGenerator:
+    _instance = None
+
+    def __new__(
+        cls,
+        db_url: str,
+        openai_key: str = None,
+        model_name="h2ogpt-sql-nsql-llama-2-7B-4bit",
+        data_input_path: str = "./table_info.jsonl",
+        sample_queries_path: str = "./samples.csv",
+        db_dialect="sqlite",
+        job_path: str = "./",
+        device: str = "auto",
+        is_regenerate: bool = False,
+        is_regenerate_with_options: bool = False,
+        eval_mode=False,
+        remote_model=False,
+        debug_mode=False
+    ):
+        # TODO: If an OpenAI model is selected, only the tokenizer needs to be loaded.
+        offloading = is_resource_low(model_name)
+        n_gpus = torch.cuda.device_count()
+        # Initially load one model at a time if the user swapped the model dynamically when GPU = 1.
+        # If GPU > 1, load multiple models in memory separately on each device.
+        # TODO
+        # Support remote model loading as an option
+
+        if (
+            (offloading and is_regenerate_with_options)
+            or (n_gpus == 1 and cls._instance and cls._instance.model_name and cls._instance.model_name != model_name)
+        ):
+            # Release the previously loaded local model before switching; skip OpenAI models,
+            # which are remote and hold no local weights.
+            if cls._instance and ("gpt-3.5" not in cls._instance.model_name and "gpt-4" not in cls._instance.model_name) and ("gpt-3.5" not in model_name and "gpt-4" not in model_name) and cls._instance.models and cls._instance.models.get(cls._instance.model_name, None):
+                _name = cls._instance.model_name
+                del cls._instance.models[_name]
+                cls._instance.models[_name] = None
+                del cls._instance.tokenizers[_name]
+                cls._instance.tokenizers[_name] = None
+
+                gc.collect()
+                torch.cuda.empty_cache()
+            logger.info(f"Low memory: {offloading}/ Model re-initialization: {is_regenerate_with_options}")
+
+        if cls._instance is None or (cls._instance and hasattr(cls._instance, 'models') and not cls._instance.models.get(model_name, None)) or not hasattr(cls._instance, 'tokenizers'):
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                cls._instance.current_temps = {}
+            # Load local models only when remote models are not selected.
+            if not remote_model:
+                if not debug_mode:
+                    # Currently, debug mode uses a remote model;
+                    # this could change in future.
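+                    # Hypothetical shape of the registries populated below, keyed by model alias:
+                    #   models     == {"h2ogpt-sql-sqlcoder2-4bit": <causal LM>}
+                    #   tokenizers == {"h2ogpt-sql-sqlcoder2-4bit": <tokenizer>}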
+ logger.info(f"Loading local model: {model_name}") + cls._instance.models, cls._instance.tokenizers = load_causal_lm_model( + model_name, + cache_path=f"{job_path}/models/", + device=device, + off_load=offloading, + re_generate=is_regenerate_with_options, + ) + else: + cls._instance.models = {} + cls._instance.model_name = "h2ogpt-sql-sqlcoder2-4bit" if not model_name else model_name + model_embed_path = f"{job_path}/models/sentence_transformers" + cls._instance.current_temps[cls._instance.model_name] = 0.5 + device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device + if not debug_mode: + # Currently. Debug mode is using remote model + # This could change in future. + cls._instance.similarity_model = load_embedding_model(model_path=model_embed_path, device=device) + return cls._instance + + def __init__( + self, + db_url: str, + openai_key: str = None, + model_name="h2ogpt-sql-nsql-llama-2-7B-4bit", + data_input_path: str = "./table_info.jsonl", + sample_queries_path: str = "./samples.csv", + job_path: str = "./", + device: str = "cpu", + db_dialect = "sqlite", + is_regenerate: bool = False, + is_regenerate_with_options: bool = False, + eval_mode = False, + debug_mode = False, + remote_model = False + ): + self.db_url = db_url + self.engine = create_engine(db_url) if db_url else None + self.sql_database = SQLDatabase(self.engine) if self.engine else None + self.dialect = db_dialect + self.context_builder = None + self.data_input_path = _check_file_info(data_input_path) + self.sample_queries_path = sample_queries_path + self.path = job_path + self._data_info = None + self._tasks = None + self.model_name = model_name + self.openai_key = openai_key + self.content_queries = None + self.is_regenerate_with_options = is_regenerate_with_options + self.is_regenerate = is_regenerate + self.device = device + self.table_name = None, + self.eval_mode = eval_mode, + self.debug_mode = debug_mode, + self.remote_model = remote_model + self.openai_client = OpenAI() if openai_key else None + self.h2ogpt_client = None + + def clear(self): + del SQLGenerator._instance + SQLGenerator._instance = None + + def load_column_samples(self, tables: list): + # TODO: Maybe we add table name as a member variable + # Load column values if they exists + examples = {} + for _t in tables: + f_p = f"{self.path}/var/lib/tmp/data/{_t}_column_values.json" + if Path(f_p).exists(): + with open(f_p, "r") as f: + examples[_t] = json.load(f) + return examples + + def build_index(self, persist: bool = True): + # Below re-assignment of the OPENAI API key is weird but without that, it throws an error. 
+ if self.openai_key: + os.environ["OPENAI_API_KEY"] = self.openai_key + openai.api_key = self.openai_key + + table_schema_index = self.context_builder.derive_index_from_context( + GPTVectorStoreIndex, + ) + if persist: + table_schema_index.save_to_disk(f"{self.path}/sql_index_check.json") + return table_schema_index + + def update_context_queries(self): + # Check if seed samples were provided + cache_path = f"{self.path}/var/lib/tmp/.cache/{self.table_name}/" + new_context_queries = [] + if self.sample_queries_path is not None and Path(self.sample_queries_path).exists(): + logger.info(f"Using QnA samples from path {self.sample_queries_path}") + new_context_queries = read_sample_pairs(self.sample_queries_path, "h2ogpt-sql") + # cache the samples for future use + make_dir(cache_path) + with open(f"{cache_path}/queries_cache.json", "w") as f: + json.dump(new_context_queries, f, indent=2) + elif self.sample_queries_path is None and Path(f"{cache_path}/queries_cache.json").exists(): + logger.info(f"Using samples from cache") + with open(f"{cache_path}/queries_cache.json", "r") as f: + new_context_queries = json.load(f) + # Read the history file and update the context queries + history_file = f"{self.path}/var/lib/tmp/.cache/{self.table_name}/history.jsonl" + try: + if Path(history_file).exists(): + with open(history_file, "r") as in_file: + for line in in_file: + # Format: + # """ + # # query: + # # answer: + # """ + if line.strip(): + data = json.loads(line) + if "Query" in data and "Answer" in data: + query = data["Query"] + response = data["Answer"] + _new_samples = f"""# query: {query}\n# answer: {response}""" + new_context_queries.append(_new_samples) + except ValueError as ve: + logger.error(f"Error in reading history file: {ve}") + pass + return new_context_queries + + def _query_tasks(self, question_str, data_info, sample_queries, table_name: list): + try: + context_file = f"{self.path}/var/lib/tmp/data/context.json" + additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {} + + system_prompt = TASK_PROMPT["system_prompt"] + user_prompt = TASK_PROMPT["user_prompt"].format( + _table_name=",".join(table_name), + _data_info=data_info, + _sample_queries=sample_queries, + _context=str(additional_context).lower(), + _question_str=question_str, + ) + # Role and content + query_txt = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] + + MODEL_CHOICE_MAP = MODEL_CHOICE_MAP_EVAL_MODE + m_name = MODEL_CHOICE_MAP.get(self.model_name) + + completion = self.openai_client.chat.completions.create( + model=m_name, + messages=query_txt, + max_tokens=512, + seed=42, + temperature=0.7 + ) + res = completion.choices[0].message.content + return res + except Exception as se: + _, ex_value, _ = sys.exc_info() + res = ex_value.statement if ex_value.statement else None + return res + + def self_correction(self, error_msg, input_query, remote_url, client_key): + try: + # Reference: Teaching Large Language Models to Self-Debug, https://arxiv.org/abs/2304.05128 + system_prompt = DEBUGGING_PROMPT["system_prompt"].format(dialect=self.dialect).strip() + user_prompt = DEBUGGING_PROMPT["user_prompt"].format(ex_traceback=error_msg, qry_txt=input_query).strip() + _response = [] + _res = input_query + self_correction_model = os.getenv("SELF_CORRECTION_MODEL", "h2oai/h2ogpt-4096-llama2-70b-chat") + if "h2ogpt-" in self_correction_model: + from h2ogpte import H2OGPTE + client = H2OGPTE(address=remote_url, api_key=client_key) + text_completion = 
client.answer_question( + system_prompt=system_prompt, + text_context_list=[], + question=user_prompt, + llm=self_correction_model) + _response = text_completion.content + elif 'gpt-3.5' in self_correction_model.lower() or 'gpt-4' in self_correction_model.lower(): + # Check if the API key is set, else inform user + _self_correction_model = MODEL_CHOICE_MAP_EVAL_MODE[self_correction_model.lower()] + query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] + completion = self.openai_client.chat.completions.create( + model=_self_correction_model, + messages=query_msg, + max_tokens=512, + seed=42, + temperature=0.7 + ) + _response = completion.choices[0].message.content + else: + raise ValueError(f"Invalid model name: {self_correction_model}") + + _response = _response.split("```sql") + _idx = [_response.index(_r) for _r in _response if _r.lower().strip().startswith("select")] + _res = _response[_idx[0]].split("```")[0].strip() + if "SELECT".lower() not in _res.lower(): + _res = input_query + result = sqlglot.transpile(_res, identify=True, write=self.dialect)[0] + return result + except Exception as se: + # Another exception occurred, return the original SQL + logger.info(f"Error in self correction: {se}") + result = _res + return result + + + def generate_response( + self, sql_index, input_prompt, attempt_fix_on_error: bool = True + ): + try: + _sql_index = sql_index.as_query_engine() + response = _sql_index.query(input_prompt) + res = response.metadata["sql_query"] + return res + except Exception as se: + # Take the SQL and make an attempt for correction + _, ex_value, ex_traceback = sys.exc_info() + qry_txt = ex_value.statement + if attempt_fix_on_error: + try: + # Attempt to heal with simple feedback + # Reference: Teaching Large Language Models to Self-Debug, https://arxiv.org/abs/2304.05128 + logger.info(f"Attempting to fix syntax error ...,\n {se}") + system_prompt = DEBUGGING_PROMPT["system_prompt"].format(dialect=self.dialect) + user_prompt = DEBUGGING_PROMPT["user_prompt"].format(ex_traceback=ex_traceback, qry_txt=qry_txt) + # Role and content + query_msg = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] + MODEL_CHOICE_MAP = MODEL_CHOICE_MAP_EVAL_MODE + m_name = MODEL_CHOICE_MAP.get(self.model_name, "gpt-3.5-turbo-1106") + + completion = self.openai_client.chat.completions.create( + model=m_name, + messages=query_msg, + max_tokens=512, + seed=42, + temperature=0.7 + ) + res = completion.choices[0].message.content + if "SELECT" not in res: + res = qry_txt + return res + except Exception as se: + # Another exception occurred, return the original SQL + res = qry_txt + return res + + def generate_tasks(self, table_names: list, input_question: str): + try: + # Step 1: Given a question, generate tasks to possibly answer the question and persist the result -> tasks.txt + # Step 2: Append task list to 'query_prompt_template', generate SQL code to answer the question and persist the result -> sql.txt + self.table_name = table_names[0] + context_queries: list = self.update_context_queries() + logger.info(f"Number of context queries found: {len(context_queries)}") + + # Remove duplicates from the context queries + m_path = f"{self.path}/models/sentence_transformers/" + duplicates_idx = remove_duplicates(context_queries, m_path) + updated_context = np.delete(np.array(context_queries), duplicates_idx).tolist() + + # Filter closest samples to the input question, threshold = 0.45 + filtered_context = ( + 
semantic_search( + input_question, + updated_context, + m_path, + threshold=0.9, + is_regenerate=True if (self.is_regenerate and not self.is_regenerate_with_options) else False, + ) + if len(updated_context) > 1 + else updated_context + ) + logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}") + _queries = "\n".join(filtered_context) + self.content_queries = _queries + + # data info + input_file = self.data_input_path + data_info = "" + with open(input_file, "r") as in_file: + for line in in_file: + if line.strip(): + data = json.loads(line) + data_info += "\n" + json.dumps(data) + self._data_info = data_info + task_list = self._query_tasks(input_question, data_info, _queries, table_names) + with open(f"{self.path}/var/lib/tmp/data/tasks.txt", "w") as f: + f.write(task_list) + return task_list + except Exception as se: + raise se + + def generate_sql( + self, + table_names: list, + input_question: str, + model_name: str = "h2ogpt-sql-nsql-llama-2-7B", + ): + # TODO: Update needed to support multiple tables + table_name = str(table_names[0].replace(" ", "_")).lower() + self.table_name = table_name + alternate_queries = [] + describe_keywords = ["describe table", "describe", "describe table schema", "describe data"] + enable_describe_qry = any([True for _dk in describe_keywords if _dk in input_question.lower()]) + + if input_question is not None and enable_describe_qry: + result = f"""SELECT "name" from PRAGMA_TABLE_INFO("{table_name}")""" + else: + context_file = f"{self.path}/var/lib/tmp/data/context.json" + additional_context = json.load(open(context_file, "r")) if Path(context_file).exists() else {} + table_context_dict = {table_name: str(additional_context).lower()} + context_queries = self.content_queries + self.context_builder = SQLContextContainerBuilder(self.sql_database, context_dict=table_context_dict) + + if model_name and "h2ogpt-sql" not in model_name: + _tasks = self.task_formatter(self._tasks) + + # TODO: The need to pass data info again could be eliminated if Task generation becomes more consistent and accurate. 
+ query_str = QUERY_PROMPT.format( + dialect=self.dialect, + _data_info=self._data_info, + _question=input_question, + _table_name=table_names, + _sample_queries=context_queries, + _tasks=_tasks, + ) + + logger.debug(f"Query Text:\n {query_str}") + # Reference: https://github.com/jerryjliu/llama_index/issues/987 + model_choices = MODEL_CHOICE_MAP_EVAL_MODE + m_name = model_choices.get(model_name, "gpt-3.5-turbo-1106") + + llm_predictor_gpt3 = LOpenAI(temperature=0.7, model_name=m_name, max_tokens=512, seed=42) + service_context_gpt3 = ServiceContext.from_defaults( + llm=llm_predictor_gpt3, chunk_size_limit=512 + ) + + table_schema_index = self.build_index(persist=False) + self.context_builder.query_index_for_context(table_schema_index, query_str, store_context_str=True) + + index = GPTSQLStructStoreIndex( + [], sql_database=self.sql_database, table_name=table_names, service_context=service_context_gpt3 + ) + + result = self.generate_response(sql_index=index, input_prompt=query_str) + try: + # Check if `SQL` is formatted ---> ``` SQL_text ``` + if "```" in str(result): + res = ( + str(result) + .split("```", 1)[1] + .split(";", 1)[0] + .strip() + .replace("```", "") + .replace("sql\n", "") + .strip() + ) + else: + res = str(result).split("Explanation:", 1)[0].strip() + res = sqlglot.transpile(res, identify=True, write=self.dialect)[0] + result = res + except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e: + logger.info("We did the best we could, there might be still be some error:\n") + logger.info(f"Realized query so far:\n {res}") + else: + if self.h2ogpt_client is None: + # Check if env variable has info about remote hosting + remote_h2ogpt_base_url = os.environ.get("H2OGPT_URL", None) + if model_name == 'h2ogpt-sql-sqlcoder-34b-alpha': + remote_h2ogpt_base_url = f"{remote_h2ogpt_base_url}:5000/v1" + elif model_name == 'h2ogpt-sql-sqlcoder2': + remote_h2ogpt_base_url = f"{remote_h2ogpt_base_url}:5001/v1" + elif model_name == 'h2ogpt-sql-nsql-llama-2-7B': + remote_h2ogpt_base_url = f"{remote_h2ogpt_base_url}:5002/v1" + else: + remote_h2ogpt_base_url = None + remote_h2ogpt_key = os.environ.get("H2OGPT_API_TOKEN", None) + _api_key = remote_h2ogpt_key if remote_h2ogpt_key else "EMPTY" + if remote_h2ogpt_base_url: + client_args = dict(base_url=remote_h2ogpt_base_url, api_key=_api_key, timeout=20.0) + self.h2ogpt_client = OpenAI(**client_args) + + # TODO Update needed for multiple tables + columns_w_type = ( + self.context_builder.full_context_dict[table_name] + .split(":")[2] + .split(" and foreign keys")[0] + .strip().replace("(", "").replace(")", "") + ) + data_samples_list = self.load_column_samples(table_names) + + _context = { + "if patterns like 'current time' or 'now' occurs in question": "always use NOW() - INTERVAL", + "if patterns like 'total number', or 'List' occurs in question": "always use DISTINCT", + "detailed summary": "include min, avg, max for numeric columns", + "summary": "include min, avg, max for numeric columns", + } + + m_path = f"{self.path}/models/sentence_transformers/" + filtered_context = semantic_search( + model_obj=self.similarity_model, + input_q=input_question, + probable_qs=list(_context.keys()), + model_path=m_path, + threshold=0.90, + ) + logger.debug(f"Filter Context: {filtered_context}") + + contextual_context = [] + for _item in filtered_context: + _val = _context.get(_item, None) + if _val: + contextual_context.append(f"{_item}: {_val}") + + logger.info("Filtering Question/Query pairs ...") + context_queries: list = 
self.update_context_queries() + logger.info(f"Number of context queries found: {len(context_queries)}") + + # Remove duplicates from the context queries + m_path = f"{self.path}/models/sentence_transformers/" + # duplicates_idx = remove_duplicates(context_queries, m_path, similarity_model=self.similarity_model) + # updated_context = np.delete(np.array(context_queries), duplicates_idx).tolist() + + # Filter closest samples to the input question, threshold = 0.9 + filtered_context = ( + semantic_search( + input_q=input_question, + probable_qs=context_queries, + model_path=m_path, + model_obj=self.similarity_model, + threshold=0.9, + is_regenerate=True if (self.is_regenerate and not self.is_regenerate_with_options) else False, + ) + if len(context_queries) > 1 + else context_queries + ) + logger.info(f"Number of possible contextual queries to question: {len(filtered_context)}") + # If QnA pairs > 5, we keep top 5 for focused context + # Most relevant match is closest to the generation post re-ranking + _samples = filtered_context + _samples = re_rank(input_question, _samples) + if len(filtered_context) > 5: + _samples = filtered_context[0:5][::-1] + _samples = re_rank(input_question, _samples) + + qna_samples = "\n".join(_samples) + + contextual_context_val = ", ".join(contextual_context) + column_names = columns_w_type.strip().split(",") + clmn_names = [i.split(" ")[0].strip() for i in column_names if i] + clmn_types = [i.split(" ")[1].strip() for i in column_names if i] + clmn_info_map = dict(zip(clmn_names, clmn_types)) + + context_columns = [] + if len(_samples) > 2: + # Check for the columns in the QnA samples provided, if exists keep them + context_columns = [_c for _c in clmn_names if _c.lower().strip() in qna_samples.lower()] + + # To be safe, when we have more than 2 samples, we check for the column names in the question as well + first_pass = [_c for _c in clmn_names if _c.lower().strip() in input_question.lower().strip()] + _input = input_question.lower().split(" ") + for _c in clmn_names: + for _f in _c.lower().split("_"): + res = _f in _input + if res: + first_pass.append(_c) + context_columns = set(context_columns + first_pass) + if len(context_columns) > 0: + contextual_data_samples = [ + _d + for _cc in context_columns + for _d in data_samples_list[table_name] + if _cc.lower() in _d.lower() + ] + data_samples_list = contextual_data_samples + + if len(context_columns) > 0: + filtered_dict = {k: f"{clmn_info_map[k]}" for k in context_columns} + filtered_c_type = ", ".join([f"{k} {v}" for k, v in filtered_dict.items()]) + _column_info = filtered_c_type if len(context_columns) > 0 else [columns_w_type] + + logger.debug(f"Relevant sample column values: {data_samples_list}") + _table_name = ", ".join(table_names) + + query_prompt_format = STARCODER2_PROMPT + if "h2ogpt-sql-nsql-llama-2-7B" in model_name: + query_prompt_format = NSQL_QUERY_PROMPT + + query = query_prompt_format.format( + table_name=_table_name, + column_info=_column_info, + data_info_detailed=data_samples_list, + sample_queries=qna_samples, + context=contextual_context_val, + question_txt=input_question, + dialect=self.dialect + ) + + logger.debug(f"Query Text:\n {query}") + device_type = "cuda" if torch.cuda.is_available() else "cpu" + + # Check if the local models were selected + current_temperature = 0.5 + tokenizer = model = generated_tokens = None + if self.models and self.tokenizers and (model_name == "h2ogpt-sql-nsql-llama-2-7B-4bit" or model_name == "h2ogpt-sql-sqlcoder2-4bit" or model_name == 
"h2ogpt-sql-sqlcoder-34b-alpha-4bit"): + tokenizer = self.tokenizers[model_name] + inputs = tokenizer([query], return_tensors="pt") + model = self.models[model_name] + current_temperature = self.current_temps.get(model_name, 0.5) + input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] + logger.info(f"Context length: {input_length}") + + # Handle limited context length + # Currently, conservative approach: remove column description from the prompt, if input_length > (2048-300) + # Others to try: + # 1. Move to a model with larger context length + # 2. Possibly use a different tokenizer for chunking + # 3. Maybe positional interpolation --> https://arxiv.org/abs/2306.15595 + if int(input_length) > 4000: + logger.info("Input length is greater than 1748, removing column description from the prompt") + query = query_prompt_format.format( + table_name=_table_name, + column_info=_column_info, + data_info_detailed="", + sample_queries=qna_samples, + context=contextual_context_val, + question_txt=input_question, + ) + logger.debug(f"Adjusted query Text:\n {query}") + inputs = tokenizer([query], return_tensors="pt") + input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1] + logger.info(f"Adjusted context length: {input_length}") + + possible_temp_gt_5 = [0.6, 0.75, 0.8, 0.9, 1.0] + possible_temp_lt_5 = [0.1, 0.2, 0.3, 0.4] + random_seed = random.randint(0, 50) + torch.manual_seed(random_seed) + random_temperature = np.random.choice(possible_temp_lt_5, 1)[0] if current_temperature >= 0.5 else np.random.choice(possible_temp_gt_5, 1)[0] + + if not self.is_regenerate_with_options and not self.is_regenerate: + # Greedy decoding, for fast response + # Reset temperature to 0.5 + current_temperature = 0.5 + if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B": + m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha") + query_txt = [{"role": "user", "content": query},] + logger.debug(f"Generation with default temperature : {current_temperature}") + completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create( + model=m_name, + messages=query_txt, + max_tokens=512, + temperature=current_temperature, + stop="```", + seed=random_seed) + generated_tokens = completion.choices[0].message.content + logger.debug(f"Generated tokens: {generated_tokens}") + else: + if model: + model.eval() + output = model.generate( + **inputs.to(device_type), + max_new_tokens=512, + temperature=current_temperature, + output_scores=True, + do_sample=True, + return_dict_in_generate=True, + ) + + generated_tokens = output.sequences[:, input_length:][0] + elif self.is_regenerate and not self.is_regenerate_with_options: + # throttle temperature for different result + logger.info("Regeneration requested on previous query ...") + logger.debug(f"Selected temperature for fast regeneration : {random_temperature}") + if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B": + m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha") + query_txt = [{"role": "user", "content": query},] + completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create( + model=m_name, + messages=query_txt, + max_tokens=512, + temperature=random_temperature, + stop="```", + seed=random_seed) + generated_tokens = completion.choices[0].message.content + 
else: + output = model.generate( + **inputs.to(device_type), + max_new_tokens=512, + temperature=random_temperature, + output_scores=True, + do_sample=True, + return_dict_in_generate=True, + ) + generated_tokens = output.sequences[:, input_length:][0] + self.current_temps[model_name] = random_temperature + logger.debug(f"Temperature saved: {self.current_temps[model_name]}") + else: + logger.info("Regeneration with options requested on previous query ...") + if model_name == "h2ogpt-sql-sqlcoder2" or model_name == "h2ogpt-sql-sqlcoder-34b-alpha" or model_name == "h2ogpt-sql-nsql-llama-2-7B": + logger.info("Generating diverse options, not enabled for remote models") + m_name = MODEL_CHOICE_MAP_EVAL_MODE.get(model_name, "h2ogpt-sql-sqlcoder-34b-alpha") + query_txt = [{"role": "user", "content": query},] + completion = self.h2ogpt_client.with_options(max_retries=3).chat.completions.create( + model=m_name, + messages=query_txt, + max_tokens=512, + temperature=random_temperature, + stop="```", + seed=random_seed) + generated_tokens = completion.choices[0].message.content + else: + # Diverse beam search decoding to explore more options + logger.debug(f"Selected temperature for diverse beam search: {random_temperature}") + output_re = model.generate( + **inputs.to(device_type), + max_new_tokens=512, + temperature=random_temperature, + top_k=5, + top_p=1.0, + num_beams=5, + num_beam_groups=5, + num_return_sequences=5, + output_scores=True, + do_sample=False, + diversity_penalty=2.0, + return_dict_in_generate=True, + ) + + transition_scores = model.compute_transition_scores( + output_re.sequences, output_re.scores, output_re.beam_indices, normalize_logits=False + ) + + # Create a boolean tensor where elements are True if the corresponding element in transition_scores is less than 0 + mask = transition_scores < 0 + # Sum the True values along axis 1 + counts = torch.sum(mask, dim=1) + output_length = inputs.input_ids.shape[1] + counts + length_penalty = model.generation_config.length_penalty + reconstructed_scores = transition_scores.sum(axis=1) / (output_length**length_penalty) + + # Converting logit scores to prob scores + probabilities_scores = F.softmax(reconstructed_scores, dim=-1) + out_idx = torch.argmax(probabilities_scores) + # Final output + output = output_re.sequences[out_idx] + generated_tokens = output[input_length:] + + logger.info(f"Generated options:\n") + prob_sorted_idxs = sorted( + range(len(probabilities_scores)), key=lambda k: probabilities_scores[k], reverse=True + ) + for idx, sorted_idx in enumerate(prob_sorted_idxs): + _out = output_re.sequences[sorted_idx] + res = tokenizer.decode(_out[input_length:], skip_special_tokens=True) + result = res.replace("table_name", _table_name).replace("```", "").strip() + if result.endswith(";"): + result = result.replace(";", "") + if "LIMIT".lower() not in result.lower(): + res = "SELECT " + result.strip() + " LIMIT 100;" + else: + res = "SELECT " + result.strip() + ";" + + pretty_sql = sqlparse.format(res, reindent=True, keyword_case="upper") + syntax_highlight = f"""``` sql\n{pretty_sql}\n```\n\n""" + alt_res = ( + f"Option {idx+1}: (_probability_: {probabilities_scores[sorted_idx]})\n{syntax_highlight}\n" + ) + alternate_queries.append(alt_res) + logger.info(f"Alternate options:\n{alt_res}") + + _res = generated_tokens + if not self.remote_model and tokenizer: + _res = tokenizer.decode(generated_tokens, skip_special_tokens=True) + # Below is a pre-caution in-case of an error in table name during generation + # COLLATE NOCASE is used to 
ignore case sensitivity; this might be specific to SQLite.
+        _temp = _res.replace("table_name", table_name) if _res and _res != '' else None
+        res = _temp
+        if not _temp:
+            res = None
+        else:
+            if _temp.endswith("```"):
+                _temp = _temp.replace("```", "")
+            _temp = _temp.split("\n```")[0].strip()
+            # TODO Below should not happen; will have to check why it's getting generated as part of the response.
+            # Not sure if it's a vLLM or prompt issue.
+            _temp = _temp.replace("/[/INST]", "").replace("[INST]", "").replace("[/INST]", "").strip()
+            if not _temp.lower().startswith("select"):
+                _temp = "SELECT " + _temp.strip()
+                res = _temp
+            if "limit" not in _temp.lower():
+                _temp = _temp.strip().replace(";", "") + " LIMIT 100;"
+                res = _temp
+            else:
+                res = _temp.strip() + ";"
+
+        # Validate the generated SQL for parsing errors, along with dialect-specific validation
+        # Note: Doesn't do well with handling date-time conversions
+        # e.g.
+        # sqlite: SELECT DATETIME(MAX(timestamp), '-5 minute') FROM demo WHERE isin_id = 'VM88109EGG92'
+        # postgres: SELECT MAX(timestamp) - INTERVAL '5 minutes' FROM demo where isin_id='VM88109EGG92'
+        # Reference ticket: https://github.com/tobymao/sqlglot/issues/2011
+        result = res
+        try:
+            result = sqlglot.transpile(res, identify=True, write=self.dialect)[0] if res else None
+        except (sqlglot.errors.ParseError, ValueError, RuntimeError) as e:
+            _, ex_value, ex_traceback = sys.exc_info()
+            logger.info(f"Attempting to fix syntax error ...\n {e}")
+            env_url = os.environ["H2OGPTE_URL"]
+            env_key = os.environ["H2OGPTE_API_TOKEN"]
+            try:
+                result = self.self_correction(input_query=res, error_msg=str(ex_traceback), remote_url=env_url, client_key=env_key)
+            except Exception as se:
+                # Another exception occurred; return the original SQL
+                logger.info(f"We did the best we could to fix the syntactical error; there might still be some issues:\n {se}")
+                logger.info(f"Realized query so far:\n {res}")
+                result = res
+        return result, alternate_queries
+
+    def task_formatter(self, input_task: str):
+        # Generated format
+        """
+        Tasks:
+        1. Generate a SELECT query to display all columns of the {selected tables}.
+        2. Infer the return type of the question as a description of the table schema.
+        3. Final output: Return the table schema for the selected table.
+        """
+
+        # Converted format
+        """
+        # 1. Generate a SELECT query to display all columns of the {selected tables}.
+        # 2. Infer the return type of the question as a description of the table schema.
+        """
+        _res = input_task.split("\n")
+        start_index = 1 if "Tasks" in _res[0] else 0
+        res = "\n".join([f"# {i}" for i in _res[start_index:]])  # Skip the first line
+        return res
diff --git a/sidekick/schema_generator.py b/sidekick/schema_generator.py
new file mode 100644
index 0000000..366c276
--- /dev/null
+++ b/sidekick/schema_generator.py
@@ -0,0 +1,62 @@
+import json
+import random
+import re
+from pathlib import Path
+from typing import Optional
+import pandas as pd
+
+
+def generate_schema(output_path: str, data_path: Optional[str] = None, column_info: Optional[dict] = None):
+    """Generate the schema information for the given data file.
+
+    Args:
+        output_path (str): The path to save the schema information.
+        data_path (Optional[str], optional): The path to the data file. Defaults to None.
+        column_info (Optional[dict], optional): A dictionary containing column information for a single table.
+            The keys are column names and the values are column types. Defaults to None.
+
+    Returns:
+        schema_list (List[dict]): A list of dictionaries containing the schema information.
+        output_path (str): The path to the schema information.
+    """
+
+    # Load the data file
+    df = pd.read_csv(data_path) if data_path else None
+    # Extract the schema information.
+    # column_info is a dictionary for a single table with column names as keys and column types as values.
+    # TODO: Extend column_info to support multiple tables.
+    schema = df.dtypes.to_dict() if df is not None else column_info
+    schema_list = []
+    special_characters = {" ": "_", ":": "_", "/": "_", "-": "_", "(": "", ")": "", ".": "_"}
+    syntax_names = ["default"]
+
+    for key, value in schema.items():
+        new_key = "".join(special_characters.get(s, s) for s in key)
+        if new_key.lower() in syntax_names:
+            new_key = new_key + "_col"
+        if new_key[0].isdigit():
+            _temp = "".join((new_key[:0], "Digit_", new_key[1:]))
+            new_key = _temp
+        if value == "object":
+            value = "TEXT"
+            unique_values = df[key].dropna().unique().tolist() if df is not None else []
+            if len(unique_values) > 10:
+                # Randomly sample 10 values
+                unique_values = random.sample(unique_values, k=10)
+            if not unique_values or (len(unique_values) > 0 and not bool(re.search(r"[A-Za-z]", unique_values[0]))):
+                schema_list.append({"Column Name": new_key, "Column Type": value})
+            else:
+                schema_list.append({"Column Name": new_key, "Column Type": value, "Sample Values": unique_values})
+        else:
+            value = "NUMERIC"
+            schema_list.append({"Column Name": new_key, "Column Type": value})
+
+    # Save the schema information in JSONL format.
+    if not Path(output_path).exists():
+        f = Path(output_path)
+        f.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w") as f:
+        for item in schema_list:
+            json.dump(item, f)
+            f.write("\n")
+    return schema_list, output_path
diff --git a/sidekick/utils.py b/sidekick/utils.py
new file mode 100644
index 0000000..8ea997d
--- /dev/null
+++ b/sidekick/utils.py
@@ -0,0 +1,631 @@
+import errno
+import glob
+import json
+import os
+import re
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+import torch
+from accelerate import infer_auto_device_map, init_empty_weights
+from h2ogpte import H2OGPTE
+from huggingface_hub import snapshot_download
+from openai import OpenAI
+from pandasql import sqldf
+from sentence_transformers import SentenceTransformer
+from sidekick.configs.prompt_template import (GUARDRAIL_PROMPT,
+                                              RECOMMENDATION_PROMPT)
+from sidekick.logger import logger
+from sklearn.metrics.pairwise import cosine_similarity
+from sqlglot import Dialects
+from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig)
+
+REMOTE_LLMS = ["h2ogpt-sql-sqlcoder-34b-alpha", "h2ogpt-sql-sqlcoder2", "h2ogpt-sql-nsql-llama-2-7B",
+               "gpt-3.5-turbo", "gpt-4-8k", "gpt-4-1106-preview-128k"]
+
+# Clones of models from https://huggingface.co/models
+# The prefix `h2ogpt-sql-` is added to avoid a conflict with the original models (we haven't made any changes to the original models yet).
+MODEL_CHOICE_MAP_EVAL_MODE = {
+    "h2ogpt-sql-sqlcoder2-4bit": "defog/sqlcoder2",
+    "h2ogpt-sql-sqlcoder-34b-alpha-4bit": "defog/sqlcoder-34b-alpha",
+    "h2ogpt-sql-nsql-llama-2-7B-4bit": "NumbersStation/nsql-llama-2-7B",
+    "h2ogpt-sql-sqlcoder2": "defog/sqlcoder2",
+    "h2ogpt-sql-sqlcoder-34b-alpha": "defog/sqlcoder-34b-alpha",
+    "h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B",
+    "gpt-3.5-turbo": "gpt-3.5-turbo-1106",
+    "gpt-4-8k": "gpt-4",
+
"gpt-4-1106-preview-128k": "gpt-4-1106-preview" + +} + +MODEL_CHOICE_MAP_DEFAULT = { + # "h2ogpt-sql-sqlcoder2-4bit": "defog/sqlcoder2", + # "h2ogpt-sql-sqlcoder-34b-alpha-4bit": "defog/sqlcoder-34b-alpha", + # "h2ogpt-sql-nsql-llama-2-7B-4bit": "NumbersStation/nsql-llama-2-7B", + # "h2ogpt-sql-sqlcoder2": "defog/sqlcoder2", + "h2ogpt-sql-sqlcoder-34b-alpha": "defog/sqlcoder-34b-alpha", + "h2ogpt-sql-nsql-llama-2-7B": "NumbersStation/nsql-llama-2-7B" +} + +# Local models for now +MODEL_DEVICE_MAP = { + "h2ogpt-sql-sqlcoder2-4bit": 0, + "h2ogpt-sql-nsql-llama-2-7B-4bit": 1, +} + +TASK_CHOICE = { + "q_a": "Ask Questions", + "sqld": "Debugging", +} + +def list_models(): + """ List all the available models. """ + return list(MODEL_CHOICE_MAP_EVAL_MODE.keys()) + + +def list_db_dialects(): + """ List all the available SQL dialects.""" + return [_d.value for _d in Dialects.__members__.values() if _d != ''] + + +def generate_sentence_embeddings(model_path: str, x, batch_size: int = 32, device: Optional[str] = None): + # Reference: + # 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models + # 2. Evaluation result: https://www.sbert.net/_static/html/models_en_sentence_embeddings.html + # 3. Model Card: https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 + # 4. Reference: https://huggingface.co/spaces/mteb/leaderboard + # Maps sentence & paragraphs to a 384 dimensional dense vector space. + model_name_path = f"{model_path}/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/" + current_torch_home = os.environ.get("TORCH_HOME", "") + if Path(model_name_path).is_dir(): + is_empty = not any(Path(model_name_path).iterdir()) + if is_empty: + # Download n cache at the specified location + # https://public.ukp.informatik.tu-darmstadt.de/reimers/sentence-transformers/v0.2/all-MiniLM-L6-v2.zip + os.environ["TORCH_HOME"] = model_path + model_name_path = "sentence-transformers/all-MiniLM-L6-v2" + sentence_model = SentenceTransformer(model_name_path, device=device) + all_res = np.zeros(shape=(len(x), 0)) + res = sentence_model.encode(x, batch_size=batch_size, show_progress_bar=True) + all_res = np.hstack((all_res, res)) + del sentence_model + os.environ["TORCH_HOME"] = current_torch_home + return all_res + + +def load_embedding_model(model_path: str, device: str): + logger.debug(f"Loading embedding model from: {model_path}") + # Check if model exists if not download and cache + local_path = Path(f"{model_path}/models--BAAI--bge-base-en/snapshots/*/") + if not Path(local_path).is_dir(): + base_path = local_path.parents[2] + snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/") + model_name_path = glob.glob(f"{model_path}/models--BAAI--bge-base-en/snapshots/*/")[0] + + sentence_model = SentenceTransformer(model_name_path, cache_folder=model_path, device=device) + if "cuda" not in device: + # Issue https://github.com/pytorch/pytorch/issues/69364 + # # In the initial experimentation, quantized model is generates slightly better results + logger.debug("Sentence embedding model is quantized ...") + model_obj = torch.quantization.quantize_dynamic(sentence_model, {torch.nn.Linear}, dtype=torch.qint8) + else: + model_obj = sentence_model + return model_obj + + +def generate_text_embeddings(model_path: str, x, model_obj=None, batch_size: int = 32, device: Optional[str] = "cpu"): + # Reference: + # 1. https://www.sbert.net/docs/pretrained_models.html#sentence-embedding-models + # Maps sentence & paragraphs to a 384 dimensional dense vector space. 
+ if model_obj is None: + model_obj = load_embedding_model(model_path, device) + + _sentences = [["Represent this sentence for retrieving duplicate examples: ", _item] for _item in x] + + res = model_obj.encode(_sentences, normalize_embeddings=True) + return res + + +def re_rank(question: str, input_x: list): + # Currently using question length as final step to re-rank, might change in future + input_pqs = [_se.strip().lower().split("answer:")[0].strip() for _se in input_x[0:5]] + _dist = np.array([len(_in.split()) for _in in input_pqs]) + + query_len = len(question.lower().split()) + logger.debug(f"Question length: {query_len}") + sorted_ = np.argsort(abs(_dist - query_len))[::-1].tolist() + res = list(np.array(input_x)[sorted_]) + return res + + +def semantic_search( + input_q: str, + probable_qs: list, + model_path: str, + model_obj=None, + threshold: float = 0.80, + device="auto", + is_regenerate: bool = False, +): + # Only consider the questions, note: this might change in future. + _inq = ("# query: " + input_q).strip().lower() + logger.debug(f"Input questions: {_inq}") + _device = "cuda" if torch.cuda.is_available() else "cpu" if device == "auto" else device + question_embeddings = generate_text_embeddings(model_path, x=[_inq], model_obj=model_obj, device=_device) + + input_pqs = [_se.split("# answer")[0].strip().lower() for _se in probable_qs] + logger.debug(f"Probable context: {input_pqs}") + embeddings = generate_text_embeddings(model_path, x=input_pqs, model_obj=model_obj, device=_device) + res = {} + _scores = {} + for idx, _se in enumerate(embeddings): + similarities_score = cosine_similarity( + [_se.astype(float).tolist()], [question_embeddings.astype(float).tolist()[0]] + ) + logger.debug(f"Similarity score for: {input_pqs[idx]}: {similarities_score[0][0]}") + _scores[idx] = similarities_score[0][0] + if similarities_score[0][0] > threshold: + res[str(probable_qs[idx])] = similarities_score[0][0] + + # Get Top N Context Queries if user requested to regenerate regardless of scores + if len(res) == 0 and is_regenerate and len(_scores) > 0: + top_n = min(len(_scores), 2) + sorted_res = dict() + sorted_scores = sorted(_scores, key=_scores.get, reverse=True) + top_idxs = sorted_scores[:top_n] + for idx in top_idxs: + sorted_res[str(probable_qs[idx])] = similarities_score[0][0] + else: + sorted_res = sorted(res.items(), key=lambda x: x[1], reverse=True) + + logger.debug(f"Sorted context: {sorted_res}") + return list(dict(sorted_res).keys()) + + +def remove_duplicates( + input_x: list, model_path: str, similarity_model=None, threshold: float = 0.989, device: str = "cpu" +): + # Remove duplicates pairs + if input_x is None or len(input_x) < 2: + res = [] + else: + embeddings = generate_text_embeddings(model_path, x=input_x, model_obj=similarity_model, device=device) + similarity_scores = cosine_similarity(embeddings) + similar_indices = [(x, y) for (x, y) in np.argwhere(similarity_scores > threshold) if x != y] + + # Remove identical pairs e.g. 
[(0, 3), (3, 0)] -> [(0, 3)] + si = [similarity_scores[tpl] for tpl in similar_indices] + dup_pairs_idx = np.where(pd.Series(si).duplicated())[0].tolist() + remove_vals = [similar_indices[_itm] for _itm in dup_pairs_idx] + [similar_indices.remove(_itm) for _itm in remove_vals] + res = list(set([item[0] for item in similar_indices])) + return res + + +def save_query( + output_path: str, table_name: str, query, response, extracted_entity: Optional[dict] = "", is_invalid: bool = False +): + _response = response + # Probably need to find a better way to extra the info rather than depending on key phrases + if response and "Generated response for question,".lower() in response.lower(): + _response = ( + response.split("**Generated response for question,**")[1].split("``` sql")[1].split("```")[0].strip() + ) + chat_history = {"Query": query, "Answer": _response, "Entity": extracted_entity} + + # Persist history for contextual reference wrt to the table. + dir_name = ( + f"{output_path}/var/lib/tmp/.cache/{table_name}" + if not is_invalid + else f"{output_path}/var/lib/tmp/.cache/{table_name}/invalid" + ) + make_dir(dir_name) + with open(f"{dir_name}/history.jsonl", "a") as outfile: + json.dump(chat_history, outfile) + outfile.write("\n") + + +def setup_dir(base_path: str): + """ Setup the required directories.""" + dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache", "models", "db/sqlite"] + for _dl in dir_list: + p = Path(f"{base_path}/{_dl}") + if not p.is_dir(): + p.mkdir(parents=True, exist_ok=True) + + +def update_tables(json_file_path: str, new_data: dict): + # Check if the JSON file exists + if os.path.exists(json_file_path): + try: + # Read the existing content from the JSON file + with open(json_file_path, "r") as json_file: + existing_data = json.load(json_file) + logger.debug("Existing Data:", existing_data) + except Exception as e: + logger.debug(f"An error occurred while reading: {e}") + else: + existing_data = {} + logger.debug("JSON file doesn't exist. Creating a new one.") + + # Append new data to the existing content + existing_data.update(new_data) + + # Write the updated content back to the JSON file + try: + with open(json_file_path, "w") as json_file: + json.dump(existing_data, json_file, indent=4) + logger.debug("Data appended and file updated.") + except Exception as e: + logger.debug(f"An error occurred while writing: {e}") + + +def read_sample_pairs(input_path: str, model_name: str = "h2ogpt-sql"): + df = pd.read_csv(input_path) + df = df.dropna() + df = df.drop_duplicates() + df = df.reset_index(drop=True) + + # NSQL format + if "h2ogpt-sql" not in model_name: + # Open AI format + # Convert frame to below format + # [ + # "# query": "" + # "# answer": "" + # ] + res = df.apply(lambda row: f"# query: {row['query']}\n# answer: {row['answer']}", axis=1).to_list() + else: + # Convert frame to below format + # [ + # "Question": + # "Answer": + # + # ] + res = df.apply(lambda row: f"Question: {row['query']}\nAnswer:\n{row['answer']}", axis=1).to_list() + return res + + +def extract_table_names(query: str): + """ + Extracts table names from a SQL query. + + Parameters: + query (str): The SQL query to extract table names from. + + Returns: + list: A list of table names. 
+ """ + table_names = re.findall(r"\bFROM\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bJOIN\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bUPDATE\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bINTO\s+(\w+)", query, re.IGNORECASE) + + # Below keywords may not be relevant for the project but adding for sake for completeness + table_names += re.findall(r"\bINSERT\s+INTO\s+(\w+)", query, re.IGNORECASE) + table_names += re.findall(r"\bDELETE\s+FROM\s+(\w+)", query, re.IGNORECASE) + return np.unique(table_names).tolist() + + +def execute_query_pd(query=None, tables_path=None, n_rows=100): + """ + Runs an SQL query on a pandas DataFrame. + + Parameters: + df (pandas DataFrame): The DataFrame to query. + query (str): The SQL query to execute. + + Returns: + pandas DataFrame: The result of the SQL query. + """ + for table in tables_path: + if not table in locals(): + # Update the local namespace with the table name, pandas object + locals()[f"{table}"] = pd.read_csv(tables_path[table]) + + res_df = sqldf(query, locals()) + return res_df + + +def get_table_keys(file_path: str, table_key: str): + res = [] + if not os.path.exists(file_path): + logger.debug(f"File '{file_path}' does not exist.") + return res, dict() + + with open(file_path, "r") as json_file: + data = json.load(json_file) + if isinstance(data, dict): + res = list(data.keys()) + if table_key: + return None, data[table_key] + else: + return res, data + + +def is_resource_low(model_name: str): + off_load = True + if not model_name: # If None, load all models + off_load = False + else: + n_gpus = torch.cuda.device_count() + logger.info(f"Number of GPUs: {n_gpus}") + device_index = 0 + if n_gpus > 1 and ("gpt-3.5" not in model_name or "gpt-4" not in model_name): + device_index = MODEL_DEVICE_MAP.get(model_name, 0) if model_name else 0 + logger.debug(f"Information on device: {device_index}") + free_in_GB = int(torch.cuda.mem_get_info(device_index)[0] / 1024**3) + total_memory = int(torch.cuda.get_device_properties(device_index).total_memory / 1024**3) + logger.info(f"Total Memory: {total_memory}GB") + logger.info(f"Free GPU memory: {free_in_GB}GB") + if (int(free_in_GB) - 2) >= int(0.3 * total_memory): + off_load = False + return off_load + + +def load_causal_lm_model( + model_type: str, + cache_path: str, + device: str, + load_in_8bit: bool = False, + load_in_4bit=True, + off_load: bool = False, + re_generate: bool = False, +): + try: + # Load h2oGPT.SQL model + # Index 0 is reserved for the default model + n_gpus = torch.cuda.device_count() + logger.info(f"Total GPUs: {n_gpus}") + models = {} + tokenizers = {} + + def _load_llm(model_type: str, device_index: int = 0, load_in_4bit=True): + device = {"": device_index} if torch.cuda.is_available() else "cpu" if device == "auto" else device + total_memory = int(torch.cuda.get_device_properties(device_index).total_memory / 1024**3) + free_in_GB = int(torch.cuda.mem_get_info(device_index)[0] / 1024**3) + logger.info(f"Free GPU memory: {free_in_GB}GB") + _load_in_8bit = load_in_8bit + model_name = model_type + logger.info(f"Loading model: {model_name} on device id: {device_index}") + logger.debug(f"Model cache: {cache_path}") + # 22GB (Least requirement on GPU) is a magic number for the current model size. + if off_load and re_generate and total_memory < 22: + # To prevent the system from crashing in-case memory runs low. + # TODO: Performance when offloading to CPU. 
+                max_memory = {device_index: "4GB"}
+                logger.info(f"Max Memory: {max_memory}, offloading to CPU")
+                with init_empty_weights():
+                    config = AutoConfig.from_pretrained(model_name, cache_dir=cache_path, offload_folder=cache_path)
+                    # A blank model with the desired config.
+                    model = AutoModelForCausalLM.from_config(config)
+                device = infer_auto_device_map(model, max_memory=max_memory)
+                device["lm_head"] = 0
+                _offload_state_dict = True
+                _llm_int8_enable_fp32_cpu_offload = True
+                _load_in_8bit = True
+                load_in_4bit = False
+            else:
+                max_memory = {device_index: f"{int(free_in_GB)-2}GB"}
+                _offload_state_dict = False
+                _llm_int8_enable_fp32_cpu_offload = False
+
+            if _load_in_8bit and _offload_state_dict and not load_in_4bit:
+                _load_in_8bit = False if "cpu" in device else True
+                logger.debug(
+                    f"Loading in 8 bit mode: {_load_in_8bit} with offloading state: {_llm_int8_enable_fp32_cpu_offload}"
+                )
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_name,
+                    cache_dir=cache_path,
+                    device_map=device,
+                    load_in_8bit=_load_in_8bit,
+                    llm_int8_enable_fp32_cpu_offload=_llm_int8_enable_fp32_cpu_offload,
+                    offload_state_dict=_offload_state_dict,
+                    max_memory=max_memory,
+                    offload_folder=f"{cache_path}/weights/",
+                )
+            else:
+                logger.debug(f"Loading in 4 bit mode: {load_in_4bit} with device {device}")
+                nf4_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_use_double_quant=True,
+                    bnb_4bit_compute_dtype=torch.bfloat16,
+                )
+
+                model = AutoModelForCausalLM.from_pretrained(
+                    model_name, cache_dir=cache_path, device_map=device, quantization_config=nf4_config
+                )
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name, cache_dir=cache_path, device_map=device, use_fast=True
+            )
+            return model, tokenizer
+
+        if not model_type:  # if None, load all models
+            for device_index in range(n_gpus):
+                model_name = list(MODEL_CHOICE_MAP_EVAL_MODE.values())[device_index]
+                model, tokenizer = _load_llm(model_name, device_index)
+                _name = list(MODEL_CHOICE_MAP_EVAL_MODE.keys())[device_index]
+                models[_name] = model
+                tokenizers[_name] = tokenizer
+        else:
+            model_name = MODEL_CHOICE_MAP_EVAL_MODE[model_type]
+            d_index = MODEL_DEVICE_MAP[model_type] if n_gpus > 1 else 0
+            model, tokenizer = _load_llm(model_name, d_index)
+            models[model_type] = model
+            tokenizers[model_type] = tokenizer
+        return models, tokenizers
+    except Exception as e:
+        logger.info(f"An error occurred while loading the model: {e}")
+        return None, None
+
+
+def _check_file_info(file_path: str):
+    if file_path is not None and Path(file_path).exists():
+        logger.info(f"Using table information from path {file_path}")
+        return file_path
+    else:
+        logger.info("Required info not found, provide a path for table information and try again")
+        raise FileNotFoundError(f"Table info not found at {file_path}")
+
+
+def _execute_sql(query: str):
+    # Check for the keyword: "Execute SQL: "
+
+    # TODO: Add a vulnerability check for possible SELECT SQL injection via source code.
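+    # Illustrative behavior of the marker check below:
+    #   _execute_sql("Execute SQL: SELECT 1")                      -> True
+    #   _execute_sql("execute sql:\nSELECT * FROM trips LIMIT 5")  -> True
+    #   _execute_sql("What is the average trip distance?")         -> False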
+    # Return True when the query is an explicit execution request, i.e. it contains
+    # the "execute sql:" marker (case-insensitive).
+    _cond = len(query.strip().lower().split("execute sql:")) > 1
+    return _cond
+
+
+def make_dir(path: str):
+    try:
+        os.makedirs(path)
+    except OSError as exc:
+        if exc.errno == errno.EEXIST and os.path.isdir(path):
+            pass
+        else:
+            raise Exception("Error reported while creating the default directory path.")
+
+
+def flatten_list(_list: list):
+    return [item for sublist in _list for item in sublist]
+
+
+def check_vulnerability(input_query: str):
+    # Ignore: `SELECT "name" FROM PRAGMA_TABLE_INFO()`
+    # Common SQL injection patterns checklist
+    # Reference:
+    # 1. https://github.com/payloadbox/sql-injection-payload-list#generic-sql-injection-payloads
+    # 2. https://www.invicti.com/blog/web-security/sql-injection-cheat-sheet/#InlineSamples
+    sql_injection_patterns = [
+        r"\b(UNION\s+ALL\s+SELECT|OR\s+\d+\s*=\s*\d+|1\s*=\s*1|--\s+)",
+        r"['\"]|(--|#)|' OR '1|' OR 1 -- -|\" OR \"\" = \"|\" OR 1 = 1 -- -|' OR '' = '|=0--+|OR 1=1|' OR 'x'='x'",
+        r'\b(SELECT\s+\*\s+FROM\s+\w+\s+WHERE\s+\w+\s*=\s*[\'"].*?[\'"]\s*;?\s*--)',
+        r'\b(INSERT\s+INTO\s+\w+\s+\(\s*\w+\s*,\s*\w+\s*\)\s+VALUES\s*\(\s*[\'"].*?[\'"]\s*,\s*[\'"].*?[\'"]\s*\)\s*;?\s*--)',
+        r"\b(DROP\s+TABLE|ALTER\s+TABLE|admin\'--)",  # DROP TABLE/ALTER TABLE
+        r"\b(?:INSERT|UPDATE|DELETE|CREATE|ALTER|DROP)\b",
+        r"(?:'|\”|--|#|‘\s*OR\s*‘1|‘\s*OR\s*\d+\s*--\s*-|\"\s*OR\s*\"\" = \"|\"\s*OR\s*\d+\s*=\s*\d+\s*--\s*-|’\s*OR\s*''\s*=\s*‘|‘=‘|'=0--+|OR\s*\d+\s*=\s*\d+|‘\s*OR\s*‘x’=‘x’|AND\s*id\s*IS\s*NULL;\s*--|‘’’’’’’’’’’’’UNION\s*SELECT\s*‘\d+|%00|/\*.*?\*/|\|\||@\w+|@@\w+)",  # Generic SQL injection patterns (Reference: https://github.com/payloadbox/sql-injection-payload-list#generic-sql-injection-payloads)
+        r"AND\s[01]|AND\s(true|false)|[01]-((true|false))",
+        r"\d+'\s*ORDER\s*BY\s*\d+--\+|\d+'\s*GROUP\s*BY\s*(\d+,)*\d+--\+|'\s*GROUP\s*BY\s*columnnames\s*having\s*1=1\s*--",
+        r"\bUNION\b\s+\b(?:ALL\s+)?\bSELECT\b\s+[A-Za-z0-9]+",  # Union based
+        r'\b(OR|AND|HAVING|AS|WHERE)\s+\d+=\d+(\s+AND\s+[\'"]\w+[\'"]\s*=\s*[\'"]\w+[\'"])?(\s*--|\s*#)?\b',
+        r"\b(?:RLIKE|IF)\s*\(\s*SELECT\s*\(\s*CASE\s*WHEN\s*\(\s*[\d=]+\s*\)\s*THEN\s*0x[0-9a-fA-F]+\s*ELSE\s*0x[0-9a-fA-F]+\s*END\s*\)\s*\)\s*AND\s*'\w+'=\w+\s*|\b%\s*AND\s*[\d=]+\s*AND\s*'\w+'=\w+\s*|and\s*\(\s*select\s*substring\s*\(\s*@@version,\d+,\d+\)\s*\)=\s*'[\w]'\b",
+        r"('|\")?\s*(or|\|\|)\s*sleep\(.*?\)\s*(\#|--)?\s*(;waitfor\s+delay\s+'[0-9:]+')?\s*;?(\s+AND\s+)?\s*\w+\s*=\s*\w+\s*",  # Time based
+        r"(ORDER BY \d+,\s*)*(ORDER BY \d+,?)*SLEEP\(\d+\),?(BENCHMARK\(\d+,\s*MD5\('[A-Z]'\)\),?)*\d*,?",  # Additional generic UNION patterns
+    ]
+
+    # Step 1:
+    # Check for SQL injection patterns in the SQL code.
+    res = False
+    _msg = None
+    p_detected = []
+    # Check whether the supplied query starts with SELECT; only SELECT queries are allowed.
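+    # For example, a stacked query such as "SELECT * FROM trips; DROP TABLE trips" is caught
+    # twice: the bare DROP keyword matches the DML/DDL pattern above, and the optional
+    # LLM-based scan in Step 2 below can flag it with an explanation.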
+ if not input_query.strip().lower().startswith("select"): + p_detected.append(['SQL keywords does not start with SELECT, only SELECT queries are allowed.']) + res = True + else: + for pattern in sql_injection_patterns: + matches = re.findall(pattern, input_query, re.IGNORECASE) + if matches: + if all(v == "'" for v in matches) or all(v == '' for v in matches): + matches = [] + else: + res = True + p_detected.append(matches) + _pd = set(flatten_list(p_detected)) + if res: + _detected_patterns = ", ".join([str(elem) for elem in _pd]) + _msg = f"The input question has malicious patterns, **{_detected_patterns}** that could lead to SQL Injection.\nSorry, I will not be able to provide an answer.\nPlease try rephrasing the question." + # Step 2: + # Step 2 is optional, if remote url is provided, check for SQL injection patterns in the generated SQL code via LLM + # Currently, only support only for models as an endpoints + logger.debug(f"Requesting additional scan using configured models") + remote_url = os.environ["H2OGPTE_URL"] + api_key = os.environ["H2OGPTE_API_TOKEN"] + + _system_prompt = GUARDRAIL_PROMPT["system_prompt"].strip() + output_schema = """{ + "type": "object", + "properties": { + "vulnerability": { + "type": "boolean" + }, + "explanation": { + "type": "string" + } + } + }""" + _user_prompt = GUARDRAIL_PROMPT["user_prompt"].format(query_txt=input_query, schema=output_schema).strip() + temp_result = None + try: + llm_scanner = os.getenv("VULNERABILITY_SCANNER", "h2oai/h2ogpt-4096-llama2-70b-chat") + if "h2ogpt-" in llm_scanner: + from h2ogpte import H2OGPTE + client = H2OGPTE(address=remote_url, api_key=api_key) + text_completion = client.answer_question( + system_prompt=_system_prompt, + text_context_list=[], + question=_user_prompt, + llm=llm_scanner) + generated_res = text_completion.content.split("\n\n")[0] + elif 'gpt-3.5' in llm_scanner.lower() or 'gpt-4' in llm_scanner.lower(): + # Check if the API key is set, else inform user + query_msg = [{"role": "system", "content": _system_prompt}, {"role": "user", "content": _user_prompt}] + _llm_scanner = MODEL_CHOICE_MAP_EVAL_MODE[llm_scanner.lower()] + openai_client = OpenAI() + completion = openai_client.chat.completions.create( + model=_llm_scanner, + messages=query_msg, + max_tokens=512, + seed=42, + temperature=0.7 + ) + generated_res = completion.choices[0].message.content + else: + raise ValueError(f"Invalid model name: {llm_scanner}") + + _res = generated_res.strip() + temp_result = json.loads(_res) if _res else None + except json.decoder.JSONDecodeError as je: + logger.error(f"Error while parsing the response: {je}") + temp_result = None + pass + + if temp_result: + vulnerable = temp_result['properties']['vulnerability'].get('value', None) + if vulnerable: + explanation_msg = temp_result['properties']['explanation'].get('value', None) + _t = " ".join([_msg, explanation_msg]) if explanation_msg and _msg else explanation_msg + _msg = _t + return res, _msg + + +def generate_suggestions(remote_url, client_key:str, column_names: list, n_qs: int=10): + results = [] + # Check if remote url contains h2o.ai/openai endpoints + if not remote_url or not client_key: + results = "Currently not supported or remote API key is missing." 
+ else: + column_info = ','.join(column_names) + input_prompt = RECOMMENDATION_PROMPT.format(data_schema=column_info, n_questions=n_qs + ) + + recommender_model = os.getenv("RECOMMENDATION_MODEL", "h2oai/h2ogpt-4096-llama2-70b-chat") + client = H2OGPTE(address=remote_url, api_key=client_key) + text_completion = client.answer_question( + system_prompt=f"Act as a data analyst, based on below data schema help answer the question", + text_context_list=[], + question=input_prompt, + llm=recommender_model + ) + _res = text_completion.content.split("\n")[2:] + results = "\n".join(_res) + return results diff --git a/start.py b/start.py new file mode 100644 index 0000000..d9fc116 --- /dev/null +++ b/start.py @@ -0,0 +1,39 @@ +import os +import shlex +import subprocess +import time +from pathlib import Path + +from huggingface_hub import snapshot_download + + +def setup_dir(base_path: str): + dir_list = ["var/lib/tmp/data", "var/lib/tmp/jobs", "var/lib/tmp/.cache", "models", "db/sqlite"] + for _dl in dir_list: + p = Path(f"{base_path}/{_dl}") + if not p.is_dir(): + p.mkdir(parents=True, exist_ok=True) + + +print(f"Download models...") +base_path = (Path(__file__).parent).resolve() if os.path.isdir("./.sidekickvenv/bin/") else "/meta_data" +setup_dir(base_path) + +# Model 1: +print(f"Download model 1...") +snapshot_download(repo_id="NumbersStation/nsql-llama-2-7B", cache_dir=f"{base_path}/models/") +# Model 2: +print(f"Download model 2...") +snapshot_download(repo_id="defog/sqlcoder2", cache_dir=f"{base_path}/models/") +# Model 3: +print(f"Download model 3...") +snapshot_download(repo_id="defog/sqlcoder-34b-alpha", cache_dir=f"{base_path}/models/") + +print(f"Download embedding model...") +snapshot_download(repo_id="BAAI/bge-base-en", cache_dir=f"{base_path}/models/sentence_transformers/") + +print("Starting SQL-Sidekick.") +DAEMON_PATH = "./.sidekickvenv/bin/uvicorn" if os.path.isdir("./.sidekickvenv/bin/") else "/resources/venv/bin/uvicorn" + +cmd = f"{DAEMON_PATH} ui.app:main" +subprocess.check_output(shlex.split(cmd)) diff --git a/static/screenshot-01.png b/static/screenshot-01.png new file mode 100644 index 0000000..949dc1a Binary files /dev/null and b/static/screenshot-01.png differ diff --git a/static/screenshot-02.png b/static/screenshot-02.png new file mode 100644 index 0000000..5761b33 Binary files /dev/null and b/static/screenshot-02.png differ diff --git a/static/screenshot-03.png b/static/screenshot-03.png new file mode 100644 index 0000000..c2cf489 Binary files /dev/null and b/static/screenshot-03.png differ diff --git a/static/screenshot-04.png b/static/screenshot-04.png new file mode 100644 index 0000000..9d91570 Binary files /dev/null and b/static/screenshot-04.png differ diff --git a/static/screenshot-05.png b/static/screenshot-05.png new file mode 100644 index 0000000..2e5f00f Binary files /dev/null and b/static/screenshot-05.png differ diff --git a/tests/.env.example b/tests/.env.example new file mode 100644 index 0000000..5b770b0 --- /dev/null +++ b/tests/.env.example @@ -0,0 +1,11 @@ +DATABRICKS_HOST = "" +DATABRICKS_CLUSTER_ID = "" +DATABRICKS_TOKEN = "" + +OPENAI_API_KEY = "" # Needed only for openAI models +MODEL_NAME = "h2ogpt-sql-sqlcoder-34b-alpha" # Others: e.g. 
gpt-4, gpt-4-32k, text-davinci-003 +QUANT_TYPE = '4bit' +H2O_BASE_MODEL_URL = 'http://38.128.233.247' +H2O_BASE_MODEL_API_KEY = "" +RECOMMENDATION_MODEL_REMOTE_URL = "" +RECOMMENDATION_MODEL_API_KEY = "" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_databricks.py b/tests/test_databricks.py new file mode 100644 index 0000000..7d0d6e2 --- /dev/null +++ b/tests/test_databricks.py @@ -0,0 +1,76 @@ +import os +from pathlib import Path + +import pytest +from dotenv import load_dotenv +from sidekick.db_config import DBConfig +from sidekick.prompter import ask +from sidekick.utils import generate_text_embeddings, setup_dir +from sklearn.metrics.pairwise import cosine_similarity + +# Rename .env.example to .env and set the mentioned env variables before running the tests +load_dotenv() + +base_path = (Path(__file__).parent / "../").resolve() +base_path = "." +cache_path = f"{base_path}/var/lib/tmp" +setup_dir(base_path) + + +def compute_similarity_score(x1: str, x2:str): + m_path = f"{base_path}/models/sentence_transformers/" + _embedding1 = generate_text_embeddings(m_path, x=[x1, x2]) + _embedding2 = generate_text_embeddings(m_path, x=[x2]) + similarities_score = cosine_similarity(_embedding1.astype(float), _embedding2.astype(float)) + return similarities_score + +# Note: Needs Databricks cluster to be running for the below tests to execute successfully +# Check if below env variables are set +assert os.environ.get("DATABRICKS_HOST") is not None +assert os.environ.get("DATABRICKS_CLUSTER_ID") is not None +assert os.environ.get("DATABRICKS_TOKEN") is not None + +DBConfig.dialect = "databricks" +# Using a demo dataset from Databricks Catalog +config_args = { + "catalog": "samples", + "schema": "nyctaxi", + "cluster_id": os.environ.get("DATABRICKS_CLUSTER_ID") +} +table_name = "trips" # sample table related to NYC Taxi dataset +DBConfig.table_name = table_name +column_info, table_info_path = DBConfig.get_column_info(output_path=f"{cache_path}/{table_name}_table_info.jsonl", **config_args) + +def test_generation_execution_correctness(): + input_q = """Compute average trip distance""" + expected_sql = """SELECT AVG(trip_distance) AS avg_distance FROM trips""" + expected_value = '2.8528291993434256' + _runtime_value = _generated_sql = "" + + result, _, _ = ask( + question=input_q, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-sqlcoder-34b-alpha", + db_dialect="databricks", + execute_db_dialect="databricks", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=False, + guardrails=False, + self_correction=True + ) + + if result and len(result) > 0: + _generated_sql = str(result[1].split("``` sql\n")[1]) + if len(result) > 4: + _runtime_value = str(result[4]) + + _syntax_score = compute_similarity_score(expected_sql, _generated_sql) + _execution_val_score = compute_similarity_score(expected_value, _runtime_value) + assert _syntax_score[0][0] > 0.9 + assert _execution_val_score[0][0] > 0.95 diff --git a/tests/test_db_setup.py b/tests/test_db_setup.py new file mode 100644 index 0000000..cf864aa --- /dev/null +++ b/tests/test_db_setup.py @@ -0,0 +1,46 @@ +import os +from pathlib import Path + +import pytest +from sidekick.prompter import ask, db_setup +from sidekick.query import SQLGenerator +from sidekick.schema_generator import generate_schema +from sidekick.utils import setup_dir + 
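+# Helper (illustrative only; not used by the assertions below): read back the JSONL schema
+# emitted by generate_schema(), one JSON object per line,
+# e.g. {"Column Name": "Age", "Column Type": "NUMERIC"}.
+def _load_schema_jsonl(path: str):
+    import json  # local import; the test module itself doesn't need json elsewhere
+    with open(path) as f:
+        return [json.loads(line) for line in f if line.strip()]
+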
+base_path = (Path(__file__).parent / "../").resolve() +base_path = "." +cache_path = f"{base_path}/var/lib/tmp" +setup_dir(base_path) + +HOST_NAME = "localhost" +USER_NAME = "sqlite" +PASSWORD = "abc" +DB_NAME = "query_test" +PORT = "5432" + +data_path = "./examples/demo/sleep_health_and_lifestyle_dataset.csv" +# Replace table_name as needed +table_name = "sleep_health_and_lifestyle" +_, table_info_path = generate_schema(data_path=data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl") +# Set DB and table to test +# Set add_sample=False if no need to add rows to the table (default: = True) +# Initialize DB +if Path(f"{base_path}/db/sqlite/{DB_NAME}.db").exists(): + os.remove(f"{base_path}/db/sqlite/{DB_NAME}.db") + + +# Currently testing sqlite setup +def test_db_setup(): + res, err = db_setup( + db_name=DB_NAME, + hostname=HOST_NAME, + user_name=USER_NAME, + password=PASSWORD, + port=PORT, + table_info_path=table_info_path, + table_samples_path=data_path, + table_name=table_name, + local_base_path=base_path + ) + assert err is None + assert res > 0 diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py new file mode 100644 index 0000000..efa9db7 --- /dev/null +++ b/tests/test_guardrails.py @@ -0,0 +1,173 @@ +import os +from pathlib import Path + +import pytest +from sidekick.prompter import ask, db_setup +from sidekick.query import SQLGenerator +from sidekick.schema_generator import generate_schema +from sidekick.utils import setup_dir + +base_path = (Path(__file__).parent / "../").resolve() +base_path = "." +cache_path = f"{base_path}/var/lib/tmp" +setup_dir(base_path) + +HOST_NAME = "localhost" +USER_NAME = "sqlite" +PASSWORD = "abc" +DB_NAME = "query_test" +PORT = "5432" + +data_path = "./examples/demo/sleep_health_and_lifestyle_dataset.csv" +# Replace table_name as needed +table_name = "sleep_health_and_lifestyle" +_, table_info_path = generate_schema(data_path=data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl") +# Set DB and table to test +# Set add_sample=False if no need to add rows to the table (default: = True) +# Initialize DB +if Path(f"{base_path}/db/sqlite/{DB_NAME}.db").exists(): + os.remove(f"{base_path}/db/sqlite/{DB_NAME}.db") + +_, err = db_setup( + db_name=DB_NAME, + hostname=HOST_NAME, + user_name=USER_NAME, + password=PASSWORD, + port=PORT, + table_info_path=table_info_path, + table_samples_path=data_path, + table_name=table_name, + local_base_path=base_path + ) + +def test_no_error(): + input_q = """SELECT * FROM sleep_health_and_lifestyle LIMIT 5;""" + + result = None + question = f"Execute SQL:\n{input_q}" + + result, _, _ = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-sqlcoder-34b-alpha", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + self_correction=False + ) + + assert 'malicious patterns' not in str(result) + +def test_blind_select_injection(): + input_q = """ + SELECT * FROM sleep_health_and_lifestyle WHERE person_id = 105 OR 1=1; + """ + + # 1. 
When guardrails are disabled + result = None + question = f"Execute SQL:\n{input_q}" + # Self correction is disabled + result, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-nsql-llama-2-7B", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + guardrails=False, + self_correction=False + ) + + assert 'malicious patterns' not in str(result) + + + # 2. When guardrails are enabled + result = None + question = f"Execute SQL:\n{input_q}" + + # Self correction is disabled + result, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-nsql-llama-2-7B", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + guardrails=True, + self_correction=False + ) + + assert 'malicious patterns' in str(result) + assert error is None + +def test_drop_injection(): + input_q = [""" + DROP sleep_health_and_lifestyle;--" + """, + """DROP sleep_health_and_lifestyle; + """, + """DROP sleep_health_and_lifestyle;#""", + """10; DROP TABLE sleep_health_and_lifestyle /*""" + ] + + + #1. Self correction is disabled + for _item in input_q: + result = None + question = f"Execute SQL:\n{_item}" + result, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-nsql-llama-2-7B", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + self_correction=False + ) + assert 'malicious patterns' in str(result) + assert 'SQL keywords does not start with SELECT' in str(result) + +@pytest.mark.parametrize("input_q, scanner", [("""SELECT * FROM sleep_health_and_lifestyle; DROP sleep_health_and_lifestyle""", "h2oai/h2ogpt-4096-llama2-70b-chat"), +("""SELECT * FROM sleep_health_and_lifestyle; DROP sleep_health_and_lifestyle""", "gpt-3.5-turbo")]) +def test_stacked_queries(input_q, scanner): + os.environ["VULNERABILITY_SCANNER"] = scanner + result = None + question = f"Execute SQL:\n{input_q}" + + result, _, _ = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-sqlcoder-34b-alpha", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + self_correction=False + ) + + assert 'malicious patterns' in str(result) + assert 'drop' in str(result) diff --git a/tests/test_llms.py b/tests/test_llms.py new file mode 100644 index 0000000..9fb0b1c --- /dev/null +++ b/tests/test_llms.py @@ -0,0 +1,175 @@ +import os +from pathlib import Path + +import pytest +from sidekick.prompter import ask, db_setup +from sidekick.query import SQLGenerator +from sidekick.schema_generator import generate_schema +from sidekick.utils import generate_text_embeddings, setup_dir +from sklearn.metrics.pairwise import cosine_similarity + +base_path = (Path(__file__).parent / "../").resolve() +base_path = "." 
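+# The resolved repo path above is overridden so that models/, db/sqlite/ and var/lib/tmp/
+# resolve relative to the directory the tests are launched from.
+# A possible CI guard (sketch only; `torch` is not imported in this module) for boxes
+# without CUDA, since the local *-4bit models load through bitsandbytes on GPU:
+#   requires_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs CUDA")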
+cache_path = f"{base_path}/var/lib/tmp" +setup_dir(base_path) + +HOST_NAME = "localhost" +USER_NAME = "sqlite" +PASSWORD = "abc" +DB_NAME = "query_test" +PORT = "5432" + +data_path = "./examples/demo/sleep_health_and_lifestyle_dataset.csv" +# Replace table_name as needed +table_name = "sleep_health_and_lifestyle" +_, table_info_path = generate_schema(data_path=data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl") +# Set DB and table to test +# Set add_sample=False if no need to add rows to the table (default: = True) +# Initialize DB +if Path(f"{base_path}/db/sqlite/{DB_NAME}.db").exists(): + os.remove(f"{base_path}/db/sqlite/{DB_NAME}.db") + +def compute_similarity_score(x1: str, x2:str): + m_path = f"{base_path}/models/sentence_transformers/" + _embedding1 = generate_text_embeddings(m_path, x=[x1, x2]) + _embedding2 = generate_text_embeddings(m_path, x=[x2]) + similarities_score = cosine_similarity(_embedding1.astype(float), _embedding2.astype(float)) + return similarities_score + + +_, err = db_setup( + db_name=DB_NAME, + hostname=HOST_NAME, + user_name=USER_NAME, + password=PASSWORD, + port=PORT, + table_info_path=table_info_path, + table_samples_path=data_path, + table_name=table_name, + local_base_path=base_path + ) + + +def test_basic_access_local(): + input_q = """What is the average sleep duration for each gender?""" + expected_1 = "Male" + expected_2 = "Female" + + result, _ar, error = ask( + question=input_q, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-nsql-llama-2-7B-4bit", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=False, + guardrails=False, + self_correction=True + ) + + assert expected_1 in str(result) + assert expected_2 in str(result) + + + +@pytest.mark.parametrize("question, model_name", [("What is the average sleep duration for each gender?", "h2ogpt-sql-sqlcoder-34b-alpha"), + ("What is the average sleep duration for each gender?", "h2ogpt-sql-nsql-llama-2-7B")] + ) +def test_basic_access_remote_models(question, model_name): + input_q = question + expected_1 = "Male" + expected_2 = "Female" + + result, _, _ = ask( + question=input_q, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name=model_name, + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=False, + guardrails=False, + self_correction=True + ) + + assert expected_1 in str(result) + assert expected_2 in str(result) + +@pytest.mark.parametrize("question, model_name", [ + ("What is the average sleep duration for each gender?", "gpt-3.5-turbo"), + ("What is the average sleep duration for each gender?", "gpt-4-8k"), + ("What is the average sleep duration for each gender?", "gpt-4-1106-preview-128k"), + ] + ) +def test_basic_access_openai_models(question, model_name): + input_q = question + expected_1 = "Male" + expected_2 = "Female" + + result, _ar, error = ask( + question=input_q, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name=model_name, + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=False, + guardrails=False, + self_correction=True + ) + + assert expected_1 in str(result) + assert expected_2 in str(result) + +@pytest.mark.parametrize("question, model_name", 
[("What are the most common occupations among individuals in the dataset?", "h2ogpt-sql-sqlcoder-34b-alpha"), +("What are the most common occupations among individuals in the dataset?", "h2ogpt-sql-nsql-llama-2-7B"), +("What are the most common occupations among individuals in the dataset?", "gpt-4-1106-preview-128k")] +) +def test_generation_execution_correctness(question, model_name): + input_q = question + expected_sql = """SELECT "Occupation", COUNT(*) AS "frequency" FROM "sleep_health_and_lifestyle" GROUP BY "Occupation" ORDER BY "frequency" DESC LIMIT 10""" + expected_value = str([('Nurse', 73), ('Doctor', 71), ('Engineer', 63), ('Lawyer', 47), ('Teacher', 40), ('Accountant', 37), ('Salesperson', 32), ('Software Engineer', 4), ('Scientist', 4), ('Sales Representative', 2), ('Manager', 1)]) + _runtime_value = _generated_sql = "" + result, _, _ = ask( + question=input_q, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name=model_name, + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=False, + guardrails=False, + self_correction=True + ) + + if result and len(result) > 0: + _idx = [result.index(_r) for _r in result if _r.startswith("``` sql")] + if _idx: + _generated_sql = str(result[_idx[0]].split("``` sql\n")[1]).replace("```", "").strip() + split_text = _generated_sql.split("\n") + _generated_sql = " ".join(split_text).strip() + result_idx = [result.index(_r) for _r in result if _r.startswith("**Result:**")] + if result_idx: + _runtime_value = str(result[result_idx[0]+1]) + + _syntax_score = compute_similarity_score(expected_sql, _generated_sql) + _execution_val_score = compute_similarity_score(expected_value, _runtime_value) + # compute similarity score + assert _syntax_score[0][0] > 0.9 + assert _execution_val_score[0][0] > 0.9 diff --git a/tests/test_self_correction.py b/tests/test_self_correction.py new file mode 100644 index 0000000..f2cd2b6 --- /dev/null +++ b/tests/test_self_correction.py @@ -0,0 +1,168 @@ +import os +from pathlib import Path + +import pytest +from sidekick.prompter import ask, db_setup +from sidekick.query import SQLGenerator +from sidekick.schema_generator import generate_schema +from sidekick.utils import setup_dir + +base_path = (Path(__file__).parent / "../").resolve() +base_path = "." +cache_path = f"{base_path}/var/lib/tmp" +setup_dir(base_path) + +HOST_NAME = "localhost" +USER_NAME = "sqlite" +PASSWORD = "abc" +DB_NAME = "query_test" +PORT = "5432" + +data_path = "./examples/demo/sleep_health_and_lifestyle_dataset.csv" +# Replace table_name as needed +table_name = "sleep_health_and_lifestyle" +_, table_info_path = generate_schema(data_path=data_path, output_path=f"{cache_path}/{table_name}_table_info.jsonl") +# Set DB and table to test +# Set add_sample=False if no need to add rows to the table (default: = True) +# Initialize DB +if Path(f"{base_path}/db/sqlite/{DB_NAME}.db").exists(): + os.remove(f"{base_path}/db/sqlite/{DB_NAME}.db") + + +_, err = db_setup( + db_name=DB_NAME, + hostname=HOST_NAME, + user_name=USER_NAME, + password=PASSWORD, + port=PORT, + table_info_path=table_info_path, + table_samples_path=data_path, + table_name=table_name, + local_base_path=base_path + ) + +def test_input1(): + input_q = """ + SELECT "age", AVERAGE("sleep_duration") AS "average_sleep_duration" group by "age" + FROM "sleep_health_and_lifestyle" + """ + + result = None + question = f"Execute SQL:\n{input_q}" + # 1. 
Self correction is disabled + _, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name='h2ogpt-sql-sqlcoder-34b-alpha', + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + self_correction=False + ) + assert error is not None + + # 2. Self correction enabled + result, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name='h2ogpt-sql-sqlcoder-34b-alpha', + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + self_correction=True + ) + assert result != input_q + assert error is None + +def test_input2(): + input_q = """ +SELECT "age_bucket", + AVG("sleep_duration") AS "average_sleep_duration" +FROM + (SELECT "age" AS "age_bucket" + FROM "sleep_health_and_lifestyle" + WHERE LOWER('Gender') LIKE CONCAT('%like%', '%Female,Male%') + AND LOWER('Occupation') LIKE '%Accountant,Doctor,Engineer,Lawyer,Manager,Nurse,Sales Representative,Salesperson,Scientist,Software Engineer,Teacher%' + AND LOWER('BMI_Category') LIKE '%Normal,Normal Weight,Obese,Overweight%' + AND LOWER('Blood_Pressure') LIKE '%115/75,%115/78,%117/76,%118/75,%118/76,%119/77%' + AND LOWER('Sleep_Disorder') LIKE '%Insomnia,Sleep Apnea%' + GROUP BY "age") AS "age_buckets" +JOIN "sleep_health_and_lifestyle" ON "age_buckets"."age_bucket" = "sleep_health_and_lifestyle"."age" +GROUP BY "age_buckets"."age_bucket" +ORDER BY "age_buckets"."age_bucket" NULLS LAST +LIMIT 100 +""" + + result = None + question = f"Execute SQL:\n{input_q}" + # 1. Self correction is disabled + _, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name=None, + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + self_correction=False + ) + + assert error is not None + + # 2. 
Self correction enabled + result, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-sqlcoder-34b-alpha", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + self_correction=True + ) + assert result != input_q + assert error is None + +@pytest.mark.parametrize("input_q, debugger", [("""SELECT CONCAT("age", " ", "heart_rate") AS "age_heart_rate" FROM "sleep_health_and_lifestyle" ORDER BY "age_heart_rate" DESC LIMIT 100 + """, "h2oai/h2ogpt-4096-llama2-70b-chat"), +("""SELECT CONCAT("age", " ", "heart_rate") AS "age_heart_rate" FROM "sleep_health_and_lifestyle" ORDER BY "age_heart_rate" DESC LIMIT 100 + """, "gpt-3.5-turbo")]) +def test_input3(input_q, debugger): + # There is no CONCAT function in SQLite + os.environ["SELF_CORRECTION_MODEL"] = debugger + question = f"Execute SQL:\n{input_q}" + res, _, error = ask( + question=question, + table_info_path=table_info_path, + sample_queries_path=None, + table_name=table_name, + is_command=False, + model_name="h2ogpt-sql-sqlcoder-34b-alpha", + is_regenerate=False, + is_regen_with_options=False, + execute_query=True, + local_base_path=base_path, + debug_mode=True, + guardrails=False, + self_correction=True + ) + assert error == None + assert res != None diff --git a/ui/app.py b/ui/app.py new file mode 100644 index 0000000..896cff2 --- /dev/null +++ b/ui/app.py @@ -0,0 +1,964 @@ +import concurrent.futures +import gc +import hashlib +import json +import os +from pathlib import Path +from typing import List, Optional + +import openai +import toml +import torch +from h2o_wave import Q, app, data, handle_on, main, on, ui +from h2o_wave.core import expando_to_dict +from sidekick.logger import logger as logging +from sidekick.prompter import (ask, data_preview, db_setup, + recommend_suggestions) +from sidekick.query import SQLGenerator +from sidekick.utils import (MODEL_CHOICE_MAP_DEFAULT, + MODEL_CHOICE_MAP_EVAL_MODE, TASK_CHOICE, + get_table_keys, save_query, setup_dir, + update_tables) + +# Load the config file and initialize required paths +app_base_path = (Path(__file__).parent / "../").resolve() +app_settings = toml.load(f"{app_base_path}/ui/ui_config.toml") +# Below check is to handle the case when the app is running on the h2o.ai cloud or locally +base_path = app_base_path if os.path.isdir("./.sidekickvenv/bin/") else "/meta_data" +tmp_path = f"{base_path}/var/lib/tmp" + +ui_title = app_settings["WAVE_UI"]["TITLE"] +ui_description = app_settings["WAVE_UI"]["SUB_TITLE"] + + +# env variables +env_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml") + +# Pre-initialize the models for faster response +def initialize_models(): + logging.info(f"Initializing models") + + _ = SQLGenerator( + db_url=None, + openai_key=None, + model_name="h2ogpt-sql-sqlcoder-34b-alpha", + job_path=base_path, + data_input_path="", + sample_queries_path="", + is_regenerate_with_options="", + is_regenerate="", + db_dialect="sqlite", + remote_model=True + ) + return + + +initialize_models() + + +async def user_variable(q: Q): + db_settings = toml.load(f"{app_base_path}/sidekick/configs/env.toml") + + q.client.db_dialect = db_settings["DB-DIALECT"]["DB_TYPE"] + q.client.host_name = db_settings["LOCAL_DB_CONFIG"]["HOST_NAME"] + q.client.user_name = db_settings["LOCAL_DB_CONFIG"]["USER_NAME"] + q.client.password = db_settings["LOCAL_DB_CONFIG"]["PASSWORD"] + q.client.db_name = 
db_settings["LOCAL_DB_CONFIG"]["DB_NAME"] + q.client.port = db_settings["LOCAL_DB_CONFIG"]["PORT"] + + tables, tables_info = get_table_keys(f"{tmp_path}/data/tables.json", None) + table_info = tables_info[tables[0]] if len(tables) > 0 else None + + q.client.table_info_path = table_info["schema_info_path"] if len(tables) > 0 else None + q.client.table_samples_path = table_info["samples_path"] if len(tables) > 0 else None + q.client.sample_qna_path = table_info["samples_qa"] if len(tables) > 0 else None + q.client.table_name = tables[0] if len(tables) > 0 else None + + q.client.model_choices = MODEL_CHOICE_MAP_DEFAULT + q.client.eval_mode = False + + +async def client_variable(q: Q): + q.client.query = None + +# Reference: https://wave.h2o.ai/docs/examples/table-markdown-pandas/ +def make_markdown_row(values): + return f"| {' | '.join([str(x) for x in values])} |" + + +def make_markdown_table(fields, rows): + return '\n'.join([ + make_markdown_row(fields), + make_markdown_row('-' * len(fields)), + '\n'.join([make_markdown_row(row) for row in rows]), + ]) + +# Use for page cards that should be removed when navigating away. +# For pages that should be always present on screen use q.page[key] = ... +def add_card(q, name, card) -> None: + q.client.cards.add(name) + q.page[name] = card + + +# Remove all the cards related to navigation. +def clear_cards(q, ignore: Optional[List[str]] = []) -> None: + if not q.client.cards: + return + + for name in q.client.cards.copy(): + if name not in ignore: + del q.page[name] + q.client.cards.remove(name) + + +@on("#chat") +async def chat(q: Q): + q.page["sidebar"].value = "#chat" + + if q.args.table_dropdown or q.args.model_choice_dropdown or q.args.task_dropdown: + # If a table/model is selected, the trigger causes refresh of the page + # so we update chat history with table name selection and return + # avoiding re-drawing. + q.page["chat_card"].data += [q.args.chatbot, False] + return + + if not q.args.chatbot: + clear_cards(q) # When routing, drop all the cards except of the main ones (header, sidebar, meta). 
+ table_names = [] + tables, _ = get_table_keys(f"{tmp_path}/data/tables.json", None) + if len(tables) > 0: + with open(f"{tmp_path}/data/tables.json", "r") as json_file: + meta_data = json.load(json_file) + for table in tables: + original_name = meta_data[table].get("original_name", q.client.original_name) + table_names.append(ui.choice(table, f"{original_name}")) + + MODEL_CHOICE_MAP = q.client.model_choices + model_choices = [ui.choice(_key, _key) for _key in MODEL_CHOICE_MAP.keys()] + q.client.model_choice_dropdown = q.args.model_choice_dropdown = "h2ogpt-sql-sqlcoder-34b-alpha" + + task_choices = [ui.choice("q_a", "Ask Questions"), ui.choice("sqld", "Debugging")] + q.client.task_choice_dropdown = q.args.task_dropdown = "q_a" + + chat_card_command_items = [ + ui.command(name="download_accept", label="Download QnA history", icon="Download"), + ui.command(name="download_reject", label="Download in-correct QnA history", icon="Download"), + ] + + add_card( + q, + "background_card", + ui.form_card( + box="horizontal", + items=[ + ui.text("Ask Questions:"), + ui.inline(items=[ui.toggle(name="demo_mode", label="Demo", trigger=True)], justify="end"), + ], + ), + ), + + add_card( + q, + "select_tables", + ui.form_card( + box="vertical", + items=[ + ui.dropdown( + name="table_dropdown", + label="Table", + required=True, + choices=table_names, + value=q.client.table_name if q.client.table_name else None, + trigger=True, + ), + ui.dropdown( + name="model_choice_dropdown", + label="Model Choice", + required=True, + choices=model_choices, + value=q.client.model_choice_dropdown if q.client.model_choice_dropdown else None, + trigger=True, + ), + ], + ), + ), + add_card( + q, + "task_choice", + ui.form_card( + box="vertical", + items=[ + ui.dropdown( + name="task_dropdown", + label="Mode", + required=True, + choices=task_choices, + value=q.client.task_choice_dropdown if q.client.task_choice_dropdown else None, + trigger=True, + ) + ], + ), + ), + if not q.args.chatbot: + add_card( + q, + "chat_card", + ui.chatbot_card( + box=ui.box("vertical", height="500px"), + name="chatbot", + data=data(fields="content from_user", t="list", size=-50), + commands=chat_card_command_items, + events=["scroll"], + ), + ), + add_card( + q, + "additional_actions", + ui.form_card( + box=ui.box("vertical", height="120px"), + items=[ + ui.buttons( + [ + ui.button( + name="suggest", + icon="", + caption="Suggests possible questions one could start with", + label="Discover", + ), + ui.button( + name="regenerate", + icon="RepeatOne", + caption="Attempts regeneration of the last response", + label="Try Again", + primary=True, + ), + ui.button( + name="regenerate_with_options", + icon="RepeatAll", + caption="Regenerates with options", + label="Try Harder", + ), + ui.button( + name="save_conversation", + caption="Saves the conversation in the history for future reference to improve response", + label="Accept", + icon="Emoji2", + ), + ui.button( + name="save_rejected_conversation", + caption="Saves the disappointed conversation to improve response.", + label="Reject", + icon="EmojiDisappointed", + ), + ], + justify="center", + ) + ], + ), + ) + + if q.args.chatbot is None or q.args.chatbot.strip() == "": + _msg = """Welcome to the SQL Sidekick!\nI am an AI assistant, i am here to help you find answers to questions on structured data. +To get started, please select a table from the dropdown and ask your question. +One could start by learning about the dataset by asking questions like: +- Describe data. 
+- Preview."""
+        q.args.chatbot = _msg
+        q.page["chat_card"].data += [q.args.chatbot, False]
+        logging.info(f"Chatbot response: {q.args.chatbot}")
+
+
+@on("chatbot")
+async def chatbot(q: Q):
+    q.page["sidebar"].value = "#chat"
+
+    # Append the user message.
+    q.page["chat_card"].data += [q.args.chatbot, True]
+
+    if q.page["select_tables"].table_dropdown.value is None or q.client.table_name is None:
+        q.page["chat_card"].data += ["Please select a table to continue!", False]
+        return
+
+    # Ignore the synthetic messages that merely echo a table/model/mode selection.
+    if (
+        f"Table {q.client.table_name} selected" in q.args.chatbot
+        or f"Model {q.client.model_choice_dropdown} selected" in q.args.chatbot
+        or "mode selected" in q.args.chatbot
+    ):
+        return
+
+    # Append the bot response.
+    question = f"{q.args.chatbot}"
+    # Check the task choice.
+    if q.client.task_dropdown == "sqld" or q.args.task_dropdown == "sqld":
+        question = f"Execute SQL:\n{q.args.chatbot}"
+        q.args.debug_mode = True
+    logging.info(f"Question: {question}")
+
+    # There are currently two regeneration modes:
+    # 1. Quick: re-samples the response by varying the sampling temperature.
+    # 2. "Try harder" mode (THM): slower, uses diverse beam search.
+    llm_response = None
+    try:
+        # A single "preview" check covers "preview data", "data preview" and
+        # "preview <table name>"; guarding on q.args.chatbot first avoids
+        # calling .lower() on None.
+        if q.args.chatbot and "preview" in q.args.chatbot.lower():
+            _response_df = data_preview(q.client.table_name)
+            # Format as a markdown table.
+            if not _response_df.empty:
+                df_markdown = make_markdown_table(fields=_response_df.columns.tolist(), rows=_response_df.values.tolist())
+                n_cols = len(_response_df.columns)
+                llm_response = f"The selected dataset has {n_cols} columns.\nBelow is a quick preview:\n{df_markdown}"
+        elif q.args.chatbot and (q.args.chatbot.lower() == "recommend questions" or q.args.chatbot.lower() == "recommend qs"):
+            llm_response = recommend_suggestions(cache_path=q.client.table_info_path, table_name=q.client.table_name)
+            if not llm_response:
+                llm_response = "Something went wrong; check the API keys provided."
+            logging.info(f"Recommended questions:\n{llm_response}")
+            q.args.chatbot = None
+        elif q.args.chatbot and q.args.chatbot.lower() == "db setup":
+            llm_response, err = db_setup(
+                db_name=q.client.db_name,
+                hostname=q.client.host_name,
+                user_name=q.client.user_name,
+                password=q.client.password,
+                port=q.client.port,
+                table_info_path=q.client.table_info_path,
+                table_samples_path=q.client.table_samples_path,
+                table_name=q.client.table_name,
+            )
+        elif (q.args.chatbot and q.args.chatbot.lower() == "regenerate") or q.args.regenerate:
+            # Attempt to regenerate the response to the last supplied query.
+            logging.info("Attempting regeneration")
+            if q.client.query is not None and q.client.query.strip() != "":
+                llm_response, alt_response, err = ask(
+                    question=q.client.query,
+                    sample_queries_path=q.client.sample_qna_path,
+                    table_info_path=q.client.table_info_path,
+                    table_name=q.client.table_name,
+                    model_name=q.client.model_choice_dropdown,
+                    is_regenerate=True,
+                    is_regen_with_options=False,
+                )
+                llm_response = "\n".join(llm_response)
+            else:
+                llm_response = (
+                    "Sure, I can generate a new response for you. "
+                    "However, to assist you effectively, could you please provide me with your question?"
+                )
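+        # "Try harder" is the slower regeneration path: it re-asks the last
+        # question with diverse beam search enabled so alternatives can be offered.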
+        elif (q.args.chatbot and q.args.chatbot.lower() == "try harder") or q.args.regenerate_with_options:
+            # Attempt to regenerate the response to the last supplied query, with options.
+            logging.info("Attempting regeneration with options.")
+            if q.client.query is not None and q.client.query.strip() != "":
+                llm_response, alt_response, err = ask(
+                    question=q.client.query,
+                    sample_queries_path=q.client.sample_qna_path,
+                    table_info_path=q.client.table_info_path,
+                    table_name=q.client.table_name,
+                    model_name=q.client.model_choice_dropdown,
+                    is_regenerate=False,
+                    is_regen_with_options=True,
+                )
+                response = "\n".join(llm_response)
+                if alt_response:
+                    llm_response = response + "\n\n" + "**Alternate options:**\n" + "\n".join(alt_response)
+                    logging.info(f"Regenerated response: {llm_response}")
+                else:
+                    llm_response = response
+            else:
+                llm_response = (
+                    "Sure, I can generate a new response for you. "
+                    "However, to assist you effectively, could you please provide me with your question?"
+                )
+        else:
+            q.client.query = question
+            # Run the potentially slow generation in a worker thread so the event loop stays responsive.
+            with concurrent.futures.ThreadPoolExecutor() as pool:
+                llm_response, alt_response, err = await q.exec(
+                    pool,
+                    ask,
+                    question=q.client.query,
+                    sample_queries_path=q.client.sample_qna_path,
+                    table_info_path=q.client.table_info_path,
+                    table_name=q.client.table_name,
+                    model_name=q.client.model_choice_dropdown,
+                    debug_mode=q.args.debug_mode,
+                )
+            llm_response = "\n".join(llm_response)
+    except (MemoryError, RuntimeError) as e:
+        logging.error(f"Something went wrong while generating response: {e}")
+        gc.collect()
+        torch.cuda.empty_cache()
+        llm_response = "Something went wrong, try executing the query again!"
+    q.client.llm_response = llm_response
+    q.page["chat_card"].data += [llm_response, False]
+
+
+@on("submit_url_keys")
+async def submit_url_keys(q: Q):
+    # Read/update environment variables from the settings form.
+    if q.args.textbox_remote_url:
+        env_settings["MODEL_INFO"]["H2OGPTE_URL"] = q.args.textbox_remote_url
+        os.environ["H2OGPTE_URL"] = q.args.textbox_remote_url
+    if q.args.textbox_h2o_api_key:
+        env_settings["MODEL_INFO"]["H2OGPTE_API_TOKEN"] = q.args.textbox_h2o_api_key
+        os.environ["H2OGPTE_API_TOKEN"] = q.args.textbox_h2o_api_key
+    if q.args.textbox_openai_api_key:
+        env_settings["MODEL_INFO"]["OPENAI_API_KEY"] = q.args.textbox_openai_api_key
+        os.environ["OPENAI_API_KEY"] = q.args.textbox_openai_api_key
+
+    # Update the settings file for future use.
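+    # Note: the keys are persisted in plain text in env.toml, which is assumed
+    # to be excluded from version control.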
+    with open(f"{app_base_path}/sidekick/configs/env.toml", "w") as f:
+        toml.dump(env_settings, f)
+    q.page["settings"].success_add_bar.visible = True
+    await q.page.save()
+    return
+
+
+@on("file_upload")
+async def fileupload(q: Q):
+    q.page["dataset"].error_bar.visible = False
+    q.page["dataset"].error_upload_bar.visible = False
+    q.page["dataset"].success_bar.visible = False
+    q.page["dataset"].progress_bar.visible = True
+
+    await q.page.save()
+
+    q.page["sidebar"].value = "#datasets"
+    usr_info_path = None
+    usr_samples_path = None
+    usr_sample_qa = None
+
+    sample_data = q.args.sample_data
+    sample_schema = q.args.data_schema
+    sample_qa = q.args.sample_qa
+
+    remove_chars = [" ", "-"]
+    org_table_name = usr_table_name = None
+    if (
+        q.args.table_name == "" or q.args.table_name is None
+    ) and sample_data and len(sample_data) > 0:
+        # User did not provide a table name; use the filename as the table name.
+        org_table_name = sample_data[0].split(".")[0].split("/")[-1]
+        logging.info(f"Using provided filename as table name: {org_table_name}")
+        q.args.table_name = org_table_name
+    if q.args.table_name:
+        org_table_name = q.args.table_name
+        usr_table_name = org_table_name.strip().lower()
+        for _c in remove_chars:
+            usr_table_name = usr_table_name.replace(_c, "_")
+
+    logging.info(f"Upload initiated for {org_table_name} with schema input: {sample_schema}")
+    if sample_data is None:
+        q.page["dataset"].error_bar.visible = True
+        q.page["dataset"].error_upload_bar.visible = False
+        q.page["dataset"].progress_bar.visible = False
+    else:
+        if sample_data:
+            usr_samples_path = await q.site.download(
+                sample_data[0], f"{tmp_path}/jobs/{usr_table_name}_table_samples.csv"
+            )
+        if sample_schema:
+            usr_info_path = await q.site.download(
+                sample_schema[0], f"{tmp_path}/jobs/{usr_table_name}_table_info.jsonl"
+            )
+        if sample_qa:
+            usr_sample_qa = await q.site.download(sample_qa[0], f"{tmp_path}/jobs/{usr_table_name}_sample_qa.csv")
+
+        q.page["dataset"].error_bar.visible = False
+
+        table_metadata = dict()
+        table_metadata[usr_table_name] = {
+            "original_name": org_table_name,
+            "schema_info_path": usr_info_path,
+            "samples_path": usr_samples_path,
+            "samples_qa": usr_sample_qa,
+        }
+        try:
+            logging.info(f"Table metadata: {table_metadata}")
+            update_tables(f"{tmp_path}/data/tables.json", table_metadata)
+
+            q.client.table_name = usr_table_name
+            q.client.table_samples_path = usr_samples_path
+            q.client.table_info_path = usr_info_path
+            q.client.sample_qna_path = usr_sample_qa
+
+            n_rows, db_resp = db_setup(
+                db_name=q.client.db_name,
+                hostname=q.client.host_name,
+                user_name=q.client.user_name,
+                password=q.client.password,
+                port=q.client.port,
+                table_info_path=q.client.table_info_path,
+                table_samples_path=q.client.table_samples_path,
+                table_name=q.client.table_name,
+            )
+            logging.info(f"DB updates: \n {db_resp}")
+            if "error" in str(db_resp).lower():
+                q.page["dataset"].error_upload_bar.visible = True
+                q.page["dataset"].error_bar.visible = False
+                q.page["dataset"].progress_bar.visible = False
+            else:
+                q.page["dataset"].progress_bar.visible = False
+                q.page["dataset"].success_bar.text = f"Data successfully uploaded, it has {n_rows:,} rows!"
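+                # Reveal the success message only after its text has been set.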
+ q.page["dataset"].success_bar.visible = True + except Exception as e: + logging.error(f"Something went wrong while uploading the dataset: {e}") + q.page["dataset"].error_upload_bar.visible = True + q.page["dataset"].error_bar.visible = False + q.page["dataset"].progress_bar.visible = False + return + +@on("#settings") +async def on_settings(q: Q): + q.page["sidebar"].value = "#settings" + clear_cards(q) # When routing, drop all the cards except of the main ones (header, sidebar, meta). + add_card(q, "settings_header", ui.form_card(box="horizontal", title="Configure", items=[])) + + toggle_state = q.client.eval_mode if q.client.eval_mode else False + add_card( + q, + "settings", + ui.form_card( + box="vertical", + items=[ + ui.textbox(name='textbox_remote_url', label='Recommendation Remote URL', + value='https://playground.h2ogpte.h2o.ai', prefix='https://'), + ui.textbox(name='textbox_h2o_api_key', label='H2O API Key', + value='', password=True), + ui.textbox(name='textbox_openai_api_key', label='OpenAI API Key', + value='', password=True), + ui.button(name="submit_url_keys", label="Add", primary=True), + ui.message_bar( + name="error_add_bar", + type="error", + text="Check the credentials provided.", + visible=False, + ), + ui.message_bar( + name="success_add_bar", + type="success", + text=f"Information added successfully!", + visible=False, + ), + ui.separator(label='Others'), + ui.toggle(name='eval_mode', label='Eval Mode', value=toggle_state)] + )) + + +@on("#datasets") +async def datasets(q: Q): + q.page["sidebar"].value = "#datasets" + clear_cards(q) # When routing, drop all the cards except of the main ones (header, sidebar, meta). + add_card(q, "data_header", ui.form_card(box="horizontal", title="Input Data", items=[])) + + add_card( + q, + "dataset", + ui.form_card( + box="vertical", + items=[ + ui.message_bar( + name="error_bar", + type="error", + text="Please input table name and upload data to get started!", + visible=False, + ), + ui.message_bar( + name="error_upload_bar", + type="error", + text="Upload failed; something went wrong. Please check the dataset name/column name for special characters and try again!", + visible=False, + ), + ui.message_bar( + name="success_bar", + type="success", + text=f"Data successfully uploaded!", + visible=False, + ), + ui.file_upload( + name="sample_data", + label="Dataset", + compact=True, + multiple=False, + file_extensions=["csv"], + required=True, + max_file_size=5000, # Specified in MB. + tooltip="Upload data to ask questions (currently only .CSV is supported)", + ), + ui.separator(label="Optional"), + ui.textbox( + name="table_name", + label="Table Name", + tooltip="Name of the table to be created, by default data filename is used!", + ), + ui.file_upload( + name="data_schema", + label="Data Schema", + multiple=False, + compact=True, + file_extensions=["jsonl"], + max_file_size=5000, # Specified in MB. + tooltip="The schema input summarizing the uploaded structured table, formats allowed are JSONL. If not provided, default schema will be inferred from the data", + ), + ui.file_upload( + name="sample_qa", + label="Sample Q&A", + multiple=False, + compact=True, + file_extensions=["csv"], + required=False, + max_file_size=5000, # Specified in MB. 
+ tooltip="Sample QnA pairs to improve contextual generation (currently only .CSV is supported)", + ), + ui.progress( + name="progress_bar", width="100%", label="Uploading datasets and creating tables!", visible=False + ), + ui.button(name="file_upload", label="Upload", primary=True), + ], + ), + ) + + +@on("#documentation") +async def about(q: Q): + q.page["meta"].script = ui.inline_script(f"window.open('https://github.com/h2oai/sql-sidekick');") + + +@on("#support") +async def on_support(q: Q): + q.page["meta"].script = ui.inline_script(f"window.open('https://github.com/h2oai/sql-sidekick/issues');") + + +@on("submit_table") +async def submit_table(q: Q): + table_key = q.args.table_dropdown + if table_key: + table_name = table_key.lower().replace(" ", "_") + _, table_info = get_table_keys(f"{tmp_path}/data/tables.json", table_name) + + q.client.table_info_path = table_info["schema_info_path"] + q.client.table_samples_path = table_info["samples_path"] + q.client.sample_qna_path = table_info["samples_qa"] + q.client.table_name = table_key.replace(" ", "_") + q.client.original_name = table_info["original_name"] + q.page["select_tables"].table_dropdown.value = table_name + else: + q.page["select_tables"].table_dropdown.value = q.client.table_name + await q.page.save() + + +async def init(q: Q) -> None: + q.client.timezone = "UTC" + q.args.demo_mode = False + q.app.toml = toml.load("app.toml") + + username, profile_pic = q.auth.username, q.app.persona_path + q.page["meta"] = ui.meta_card( + script=heap_analytics( + userid=q.auth.subject, + event_properties=f"{{" + f"version: '{q.app.toml['App']['Version']}', " + f"product: '{q.app.toml['App']['Title']}'" + f"}}", + ), + box="", + layouts=[ + ui.layout( + breakpoint="xs", + min_height="100vh", + zones=[ + ui.zone( + "main", + size="1", + direction=ui.ZoneDirection.ROW, + zones=[ + ui.zone("sidebar", size="250px"), + ui.zone( + "body", + zones=[ + ui.zone( + "content", + zones=[ + # Specify various zones and use the one that is currently needed. Empty zones are ignored. + ui.zone("horizontal", direction=ui.ZoneDirection.ROW), + ui.zone("vertical"), + ui.zone( + "grid", direction=ui.ZoneDirection.ROW, wrap="stretch", justify="center" + ), + ], + ), + ], + ), + ], + ) + ], + ) + ], + ) + q.page["sidebar"] = ui.nav_card( + box="sidebar", + color="primary", + title="QnA Assistant", + subtitle="Get answers to your questions.", + value=f'#{q.args["#"]}' if q.args["#"] else "#chat", + image="https://wave.h2o.ai/img/h2o-logo.svg", + items=[ + ui.nav_group( + "Menu", + items=[ + ui.nav_item(name="#datasets", label="Upload Dataset", icon="Database"), + ui.nav_item(name="#chat", label="Chat", icon="Chat"), + ui.nav_item(name="#settings", label="Settings", icon="Settings") + ], + ), + ui.nav_group( + "Help", + items=[ + ui.nav_item(name="#documentation", label="Documentation", icon="TextDocument"), + ui.nav_item(name="#support", label="Support", icon="Telemarketer"), + ], + ), + ], + secondary_items=[ + ui.persona( + title=username, + size="xs", + image=profile_pic, + ), + ], + ) + + # Connect to LLM + openai.api_key = "" + + await user_variable(q) + await client_variable(q) + # If no active hash present, render chat. + if q.args["#"] is None: + await chat(q) + + +def on_shutdown(): + logging.info("App stopped. 
Goodbye!") + + +# Preload sample data for the app +def upload_demo_examples(q: Q): + upload_action = True + cur_dir = os.getcwd() + sample_data_path = f"{cur_dir}/examples/demo" + org_table_name = "Sleep health and lifestyle study" + usr_table_name = org_table_name.lower().replace(" ", "_") + + table_metadata_path = f"{tmp_path}/data/tables.json" + # Do not upload dataset if user had any tables uploaded previously. This check avoids re-uploading sample dataset. + if os.path.exists(table_metadata_path): + # Read the existing content from the JSON file + with open(table_metadata_path, "r") as json_file: + existing_data = json.load(json_file) + if usr_table_name in existing_data: + upload_action = False + logging.info(f"Dataset already uploaded, skipping upload!") + + if upload_action: + table_metadata = dict() + table_metadata[usr_table_name] = { + "original_name": org_table_name, + "schema_info_path": f"{sample_data_path}/table_info.jsonl", + "samples_path": f"{sample_data_path}/sleep_health_and_lifestyle_dataset.csv", + "samples_qa": None, + } + update_tables(f"{tmp_path}/data/tables.json", table_metadata) + + q.client.org_table_name = org_table_name + q.client.table_name = usr_table_name + q.client.table_samples_path = f"{sample_data_path}/sleep_health_and_lifestyle_dataset.csv" + q.client.table_info_path = f"{sample_data_path}/table_info.jsonl" + q.client.sample_qna_path = None + + _, db_resp = db_setup( + db_name=q.client.db_name, + hostname=q.client.host_name, + user_name=q.client.user_name, + password=q.client.password, + port=q.client.port, + table_info_path=q.client.table_info_path, + table_samples_path=q.client.table_samples_path, + table_name=q.client.table_name, + ) + logging.info(f"DB updated with demo examples: \n {db_resp}") + q.args.table_dropdown = usr_table_name + + +async def on_event(q: Q): + event_handled = False + args_dict = expando_to_dict(q.args) + logging.info(f"Args dict {args_dict}") + + if q.args.regenerate_with_options: + q.args.chatbot = "try harder" + elif q.args.regenerate: + q.args.chatbot = "regenerate" + q.client.eval_mode = False + + if q.args.suggest: + q.args.chatbot = "Recommend questions" + await chatbot(q) + event_handled = True + if q.args.eval_mode: + q.client.eval_mode = True + q.client.model_choices = MODEL_CHOICE_MAP_EVAL_MODE + await chat(q) + event_handled = True + if q.args.table_dropdown and not q.args.chatbot and q.client.table_name != q.args.table_dropdown: + logging.info(f"User selected table: {q.args.table_dropdown}") + await submit_table(q) + q.args.chatbot = f"Table {q.args.table_dropdown} selected" + # Refresh response is triggered when user selects a table via dropdown + event_handled = True + if ( + q.args.model_choice_dropdown + and not q.args.chatbot and q.args.model_choice_dropdown != q.client.model_choice_dropdown + ): + logging.info(f"User selected model type: {q.args.model_choice_dropdown}") + q.client.model_choice_dropdown = q.args.model_choice_dropdown + q.page["select_tables"].model_choice_dropdown.value = q.client.model_choice_dropdown + q.args.chatbot = f"Model {q.client.model_choice_dropdown} selected" + # Refresh response is triggered when user selects a table via dropdown + q.args.model_choice_dropdown = None + event_handled = True + if q.args.task_dropdown and not q.args.chatbot and q.client.task_dropdown != q.args.task_dropdown: + logging.info(f"User selected task: {q.args.task_dropdown}") + q.client.task_dropdown = q.args.task_dropdown + q.page["task_choice"].task_dropdown.value = q.client.task_dropdown + 
q.args.chatbot = f"'{TASK_CHOICE[q.client.task_dropdown]}' mode selected" + q.args.task_dropdown = None + # Refresh response is triggered when user selects a table via dropdown + event_handled = True + if ( + q.args.save_conversation + or q.args.save_rejected_conversation + or (q.args.chatbot and "save the qna pair:" in q.args.chatbot.lower()) + ): + question = q.client.query + _val = q.client.llm_response + # Currently, any manual input by the user is a Question by default + table_name = q.client.table_name if q.client.table_name else "default" + _is_invalid = True if q.args.save_rejected_conversation else False + _msg = ( + "Conversation saved successfully!" + if not _is_invalid + else "Sorry, we couldn't get it right, we will try to improve!" + ) + if ( + question is not None + and "SELECT" in question + and (question.lower().startswith("question:") or question.lower().startswith("q:")) + ): + _q = question.lower().split("q:")[1].split("r:")[0].strip() + _r = question.lower().split("r:")[1].strip() + logging.info(f"Saving conversation for question: {_q} and response: {_r}") + save_query(base_path, table_name, query=_q, response=_r, is_invalid=_is_invalid) + elif question is not None and _val is not None and _val.strip() != "": + logging.info(f"Saving conversation for question: {question} and response: {_val}") + save_query(base_path, table_name, query=question, response=_val, is_invalid=_is_invalid) + else: + _msg = "Sorry, try generating a conversation to save." + q.page["chat_card"].data += [_msg, False] + event_handled = True + elif q.args.download_accept: + result_path = f"{base_path}/var/lib/tmp/.cache/{q.client.table_name}/history.jsonl" + # Check if path exists + # If the model selected is GPT models from openAI then disable download + # We don't want to use those for further improvements externally. + if Path(result_path).exists() and "gpt-4" not in q.client.model_choice_dropdown and "gpt-3.5-turbo" not in q.client.model_choice_dropdown: + logging.info(f"Downloading accepted QnA history for table: {q.client.table_name}") + (server_path,) = await q.site.upload([result_path]) + q.page["meta"].script = ui.inline_script(f'window.open("{server_path}", "_blank");') + os.remove(result_path) + _msg = "Download complete!" + else: + _msg = "No history found!" + q.page["chat_card"].data += [_msg, False] + event_handled = True + elif q.args.download_reject and "gpt-4" not in q.client.model_choice_dropdown and "gpt-3.5" not in q.client.model_choice_dropdown: + logging.info(f"Downloading rejected QnA history for table: {q.client.table_name}") + result_path = f"{base_path}/var/lib/tmp/.cache/{q.client.table_name}/invalid/history.jsonl" + if Path(result_path).exists(): + (server_path,) = await q.site.upload([result_path]) + q.page["meta"].script = ui.inline_script(f'window.open("{server_path}", "_blank");') + os.remove(result_path) + _msg = "Download complete!" + else: + _msg = "No history found!" + q.page["chat_card"].data += [_msg, False] + event_handled = True + elif q.args.regenerate or q.args.regenerate_with_options: + await chatbot(q) + event_handled = True + elif q.args.demo_mode: + logging.info(f"Switching to demo mode!") + # If demo datasets are not present, register them. + upload_demo_examples(q) + logging.info(f"Demo dataset selected: {q.client.table_name}") + await submit_table(q) + sample_qs = """ + Data description: The Sleep Health and Lifestyle Dataset comprises 400 rows and 13 columns, + covering a wide range of variables related to sleep and daily habits. 
+        It includes details such as gender, age, occupation, sleep duration, quality of sleep,
+        physical activity level, stress levels, BMI category, blood pressure, heart rate, daily steps,
+        and the presence or absence of sleep disorders\n
+        Reference: https://www.kaggle.com/datasets/uom190346a/sleep-health-and-lifestyle-dataset \n
+        Example questions:\n
+        1. Describe data/Preview data. Tip: For more detailed insights on the data, try AutoInsights on the Cloud marketplace.
+        2. What is the average sleep duration for each gender?
+        3. How does average sleep duration vary across different age groups?
+        4. What are the most common occupations among individuals in the dataset?
+        5. What is the average sleep duration for each occupation?
+        6. What is the average sleep duration for each age group?
+        7. What is the effect of Physical Activity Level on Quality of Sleep?
+        """
+        q.args.chatbot = (
+            f"Demo mode is enabled.\nTry the example questions below for the selected data to get started:\n{sample_qs}"
+        )
+        q.page["chat_card"].data += [q.args.chatbot, False]
+        q.args.table_dropdown = None
+        q.args.model_choice_dropdown = None
+        q.args.task_dropdown = None
+        await chat(q)
+        event_handled = True
+    else:  # Default chatbot event.
+        await handle_on(q)
+        event_handled = True
+    logging.info(f"Event handled: {event_handled} ...")
+    return event_handled
+
+
+# Record analytics and usage.
+def heap_analytics(userid, event_properties=None) -> Optional[ui.inline_script]:
+    if "HEAP_ID" not in os.environ:
+        return None
+
+    heap_id = os.getenv("HEAP_ID")
+    script = f"""
+window.heap=window.heap||[],heap.load=function(e,t){{window.heap.appid=e,window.heap.config=t=t||{{}};var r=document.createElement("script");r.type="text/javascript",r.async=!0,r.src="https://cdn.heapanalytics.com/js/heap-"+e+".js";var a=document.getElementsByTagName("script")[0];a.parentNode.insertBefore(r,a);for(var n=function(e){{return function(){{heap.push([e].concat(Array.prototype.slice.call(arguments,0)))}}}},p=["addEventProperties","addUserProperties","clearEventProperties","identify","resetIdentity","removeEventProperty","setEventProperties","track","unsetEventProperty"],o=0;o