Skip to content

Commit

Permalink
Chore: table data parameter for DB retrieval is now inputted via txt …
Browse files Browse the repository at this point in the history
…file
  • Loading branch information
mateoperezrivera committed Jun 13, 2024
1 parent e8cdd79 commit 70f88cc
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 8 deletions.
1 change: 1 addition & 0 deletions db_table_info.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"Table1" : "Explain information stored in table,columns and connections with other tables", "Table2" : "Explain information stored in table,columns and connections with other tables"}
3 changes: 1 addition & 2 deletions orc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@ async def main(req: func.HttpRequest) -> func.HttpResponse:
'id': client_principal_id,
'name': client_principal_name
}
db_table_info= req_body.get('db_table_info')

if question:

result = await orchestrator.run(conversation_id, question, client_principal,db_table_info)
result = await orchestrator.run(conversation_id, question, client_principal)

return func.HttpResponse(json.dumps(result), mimetype="application/json", status_code=200)
else:
Expand Down
4 changes: 2 additions & 2 deletions orc/code_orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
DB_RETRIEVAL = os.environ.get("DB_RETRIEVAL") or "true"
DB_RETRIEVAL = True if DB_RETRIEVAL.lower() == "true" else False

async def get_answer(history,db_table_info):
async def get_answer(history):


#############################
Expand Down Expand Up @@ -196,7 +196,7 @@ async def get_answer(history,db_table_info):

#run sql retrieval function
if(DB_RETRIEVAL):
db_function_result= await kernel.invoke(retrievalPlugin["DBRetrieval"], sk.KernelArguments(input=search_query,db_table_info=db_table_info))
db_function_result= await kernel.invoke(retrievalPlugin["DBRetrieval"], sk.KernelArguments(input=search_query))
formatted_sources = db_function_result.value[:100].replace('\n', ' ')
escaped_sources = escape_xml_characters(db_function_result.value)
db_sources=escaped_sources
Expand Down
4 changes: 2 additions & 2 deletions orc/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_credentials():
# return DefaultAzureCredential(exclude_managed_identity_credential=is_local_env, exclude_environment_credential=is_local_env)
return DefaultAzureCredential()

async def run(conversation_id, ask, client_principal,db_table_info):
async def run(conversation_id, ask, client_principal):

start_time = time.time()

Expand Down Expand Up @@ -70,7 +70,7 @@ async def run(conversation_id, ask, client_principal,db_table_info):

# get rag answer and sources
logging.info(f"[orchestrator] executing RAG retrieval using code orchestration")
answer_dict = await code_orchestration.get_answer(history,db_table_info)
answer_dict = await code_orchestration.get_answer(history)

# 3) update and save conversation (containing history and conversation data)

Expand Down
15 changes: 13 additions & 2 deletions orc/plugins/Retrieval/native_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ def BingRetrieval(
name="DBRetrieval",
)
def DBRetrieval(self,
input: Annotated[str, "The user question"],
db_table_info: Annotated[str, "The tables to search for information"]
input: Annotated[str, "The user question"]
) -> Annotated[str, "the output is a string with the search results"]:
logging.info('Python HTTP trigger function processed a request.')

Expand All @@ -223,7 +222,19 @@ def DBRetrieval(self,
logging.error(f"[DBRetrieval] Invalid db_type specified")
return ""
azureOpenAIKey = get_secret("azureOpenAIKey")
#Get table data
try:

# Get table information from file
with open('db_table_info.txt', 'r') as file:
db_table_info = file.read()

except FileNotFoundError:
logging.error("[DBRetrieval] db_table_info.txt not found")
return ""
except Exception as e:
logging.error(f"[DBRetrieval] Unexpected error: {e}")
return ""
# Log configuration variables
logging.info(f"[{DB_TYPE} Retrieval] Server: {DB_SERVER}")
logging.info(f"[{DB_TYPE} Retrieval] Database: {DB_DATABASE}")
Expand Down

0 comments on commit 70f88cc

Please sign in to comment.