Merge pull request #46 from oracle-samples/cdb-fix
CDB fix
corradodebari authored Nov 20, 2024
2 parents cec3d87 + 74b47dd commit 8286235
Showing 7 changed files with 42 additions and 19 deletions.
6 changes: 6 additions & 0 deletions README.md
@@ -52,6 +52,12 @@ To run the application on bare-metal; download the [source](https://github.com/o
    pip3 install -r app/requirements.txt
    ```

+1. Exit from your shell and run again:
+
+   ```bash
+   source .venv/bin/activate
+   ```
+
 1. Start Streamlit:

    ```bash
7 changes: 5 additions & 2 deletions app/src/content/api_server.py
@@ -45,8 +45,11 @@ def display_logs():
             st.chat_message("human").write(msg["message"])
         else:
             if state.rag_params["enable"]:
-                st.chat_message("ai").write(msg["answer"])
-                st_common.show_rag_refs(msg["context"])
+                logger.info("msg[\"answer\"]")
+                logger.info(msg)
+                st.chat_message("ai").write(msg)
+                if "context" in msg and msg["context"]:
+                    st_common.show_rag_refs(msg["context"])
             else:
                 st.chat_message("ai").write(msg.content)
     except api_server.queue.Empty:
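The added guard avoids a KeyError when a queued message carries no retrieval context. A minimal sketch of the pattern, with an illustrative message dict that is not taken from the repository:

```python
# Illustrative only: shows why `"context" in msg and msg["context"]` is the
# safe check before rendering RAG references.
msg = {"answer": "Example answer"}          # hypothetical queue message

if "context" in msg and msg["context"]:     # False here: the key is absent
    print("references:", msg["context"])
else:
    print("no references to render")        # taken instead of raising KeyError
```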
6 changes: 2 additions & 4 deletions app/src/content/test_framework.py
@@ -188,10 +188,6 @@ def main():
     st_common.set_default_state("test_set", None)
     st_common.set_default_state("temp_dir", None)

-    # TO CHECK
-    # file_path = "temp_file.csv"
-    # if os.path.exists(file_path):
-    # os.remove(file_path)

     if state.toggle_generate_test:
         st.header("Q&A Test Dataset Generation")
@@ -256,6 +252,8 @@ def main():
             # Generate Q&A
             qa_file = os.path.join(state["temp_dir"], f"{file_name}_{str(qa_count)}_test_set.json")
             state.qa_file = qa_file
+            logger.info("calling generate with client: ")
+            logger.info(llm_client)
             state.test_set = utilities.generate_qa(qa_file, kb, qa_count, client=llm_client)
             placeholder.empty()
             st.success("Q&A Generation Succeeded.", icon="✅")
15 changes: 11 additions & 4 deletions app/src/modules/api_server.py
@@ -103,9 +103,14 @@ def do_POST(self): # pylint: disable=invalid-name
         try:
             # Parse the POST data as JSON
             post_json = json.loads(post_data)
+            messages = post_json.get("messages")
+
+            # Extract the content of 'user'
+            message = next(item['content'] for item in messages if item['role'] == 'user')
+
+            logger.info(messages)
+            logger.info(message)

-            # Extract the 'message' field from the JSON
-            message = post_json.get("message")
             response = None
             if message:
                 # Log the incoming message
@@ -123,6 +128,8 @@ def do_POST(self): # pylint: disable=invalid-name
                 stream=False,
             )
             self.send_response(200)
+            logger.info("RESPONSE:")
+            logger.info(response)
             # Process response to JSON
         else:
             json_response = {"error": "No 'message' field found in request."}
@@ -163,9 +170,9 @@ def do_POST(self): # pylint: disable=invalid-name
                 for i in range(max_items):
                     chunk = full_context[i]
                     sources.add(os.path.basename(chunk.metadata["source"]))
-                json_response = {"answer": response["answer"], "sources": list(sources)}
+                json_response = {"choices": [{"text":response["answer"],"index":0}], "sources": list(sources)}
             else:
-                json_response = {"answer": response.content}
+                json_response = {"choices": [{"text":response["answer"],"index":0}]}

             self.wfile.write(json.dumps(json_response).encode("utf-8"))
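With this change the handler accepts an OpenAI-style `messages` array and answers with a `choices` list. A minimal sketch of a client call; the host, port, and path below are assumptions for illustration, since the diff does not show where the server listens:

```python
import json
import urllib.request

payload = {
    "model": "",
    "messages": [{"role": "user", "content": "What is a CDB?"}],
}
req = urllib.request.Request(
    "http://localhost:8000/",                    # assumed address, not from the diff
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    data = json.loads(resp.read())

print(data["choices"][0]["text"])                # answer text
print(data.get("sources", []))                   # present only on the RAG branch
```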
10 changes: 5 additions & 5 deletions app/src/modules/report_utils.py
@@ -120,7 +120,7 @@ def record_update():
         "question": state.question_input,
         "reference_answer": state.reference_answer_input,
         "reference_context": state.reference_context_input,
-        "conversation_history": "",
+        "conversation_history": [],
     }
     for key, value in new_data.items():
         state.df.at[index, key] = value
@@ -129,12 +129,12 @@ def record_update():
     # Text input for the question, storing the user's input in the session state
     state.index_output = st.write("Record: " + str(state.index + 1) + "/" + str(state.df.shape[0]))
     state.hide_input = st.checkbox("Hide", value=state.hide_input)
-    state.question_input = st.text_area("question", height=1, value=state.question_input)
-    state.reference_answer_input = st.text_area("Reference answer", height=1, value=state.reference_answer_input)
+    state.question_input = st.text_area("question", height=100, value=state.question_input)
+    state.reference_answer_input = st.text_area("Reference answer", height=100, value=state.reference_answer_input)
     state.reference_context_input = st.text_area(
-        "Reference context", height=10, value=state.reference_context_input, disabled=True
+        "Reference context", height=100, value=state.reference_context_input, disabled=True
     )
-    state.metadata_input = st.text_area("Metadata", height=1, value=state.metadata_input, disabled=True)
+    state.metadata_input = st.text_area("Metadata", height=100, value=state.metadata_input, disabled=True)

     if save_clicked:
         st.success("Q&A saved successfully!")
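The likely motive for the height changes (an inference, not stated in the commit): recent Streamlit releases enforce a minimum `st.text_area` height of roughly 68 px, so values like `height=1` are rejected. A minimal sketch:

```python
import streamlit as st

# height=100 satisfies Streamlit's minimum text_area height; tiny values
# such as height=1 fail validation in recent releases.
question = st.text_area("question", height=100, value="What is a CDB?")
st.write("You asked:", question)
```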
5 changes: 3 additions & 2 deletions app/src/modules/utilities.py
@@ -82,7 +82,8 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
     lm_params = ll_models_config[model]

     logger.info(
-        "Configuring LLM - URL: %s; Temp - %s; Max Tokens - %s",
+        "Configuring LLM - Model: %s; URL: %s; Temp - %s; Max Tokens - %s",
+        model,
         lm_params["url"],
         lm_params["temperature"][0],
         lm_params["max_tokens"][0],
@@ -115,7 +116,7 @@ def get_ll_model(model, ll_models_config=None, giskarded=False):
     elif llm_api == "ChatPerplexity":
         client = ChatPerplexity(pplx_api_key=lm_params["api_key"], model_kwargs=common_params)
     elif llm_api == "ChatOllama":
-        client = ChatOllama(base_url=lm_params["url"], model_kwargs=common_params)
+        client = ChatOllama(model=model,base_url=lm_params["url"], model_kwargs=common_params)
     ## End - Add Additional Model Authentication Here
     api_accessible, err_msg = is_url_accessible(llm_url)
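Passing `model=model` matters because `ChatOllama` otherwise falls back to its default model name, which may not be pulled on the server even when the base URL is reachable. A minimal sketch, assuming `langchain_community` is installed and a local Ollama instance is serving the (illustrative) model below:

```python
from langchain_community.chat_models import ChatOllama

client = ChatOllama(
    model="llama3.1",                    # illustrative; use any pulled model
    base_url="http://localhost:11434",   # Ollama's default endpoint
)
print(client.invoke("Reply with one word.").content)
```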
12 changes: 10 additions & 2 deletions widget/app.js
@@ -48,10 +48,17 @@ document.addEventListener('DOMContentLoaded', () => {
     async function getBotResponse(userMessage) {
         const apiUrl = chatBotUrl;
         const messagePayload = {
-            message: userMessage
+            model: "",
+            messages: [
+                {
+                    role: "user",
+                    content: userMessage
+                }
+            ]
         };


+
         try {
             const response = await fetch(apiUrl, {
                 method: 'POST',
@@ -63,7 +70,8 @@ document.addEventListener('DOMContentLoaded', () => {
             });

             const data = await response.json();
-            return data.choices[0].message.content;
+            //return data.answer;
+            return data.choices[0].text+"\n Source: "+data.sources[0];
         } catch (error) {
             console.error('Error fetching API:', error);
             return "Sorry, I couldn't connect to the server.";
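Reading the widget and server changes together: the widget now returns `data.choices[0].text` plus `data.sources[0]`, while the server's non-RAG branch above emits no `sources` key. A small sketch of the two response shapes reconstructed from this diff, with a defensive read for illustration:

```python
# Both shapes come from the modules/api_server.py hunk above; the values
# are illustrative.
rag_response = {
    "choices": [{"text": "Example answer", "index": 0}],
    "sources": ["guide.pdf"],
}
plain_response = {
    "choices": [{"text": "Example answer", "index": 0}],  # no "sources" key
}

for resp in (rag_response, plain_response):
    text = resp["choices"][0]["text"]
    source = resp.get("sources", ["(no source)"])[0]      # avoids a KeyError
    print(f"{text} | Source: {source}")
```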
