Merged
1 change: 1 addition & 0 deletions .gitignore
@@ -138,3 +138,4 @@ ml-models.iml
ml-models.ipr

.DS_Store
.databricks
31 changes: 2 additions & 29 deletions dash-chatbot-app/DatabricksChatbot.py
@@ -1,24 +1,13 @@
import dash
from dash import html, Input, Output, State, dcc
import dash_bootstrap_components as dbc
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole

from model_serving_utils import query_endpoint

class DatabricksChatbot:
def __init__(self, app, endpoint_name, height='600px'):
self.app = app
self.endpoint_name = endpoint_name
self.height = height

try:
print('Initializing WorkspaceClient...')
self.w = WorkspaceClient()
print('WorkspaceClient initialized successfully')
except Exception as e:
print(f'Error initializing WorkspaceClient: {str(e)}')
self.w = None

self.layout = self._create_layout()
self._create_callbacks()
self._add_custom_css()
@@ -111,25 +100,9 @@ def clear_chat(n_clicks):
return dash.no_update, dash.no_update

def _call_model_endpoint(self, messages, max_tokens=128):
if self.w is None:
raise Exception('WorkspaceClient is not initialized')

chat_messages = [
ChatMessage(
content=message['content'],
role=ChatMessageRole[message['role'].upper()]
) for message in messages
]
try:
print('Calling model endpoint...')
response = self.w.serving_endpoints.query(
name=self.endpoint_name,
messages=chat_messages,
max_tokens=max_tokens
)
message = response.choices[0].message.content
print('Model endpoint called successfully')
return message
return query_endpoint(self.endpoint_name, messages, max_tokens)["content"]
except Exception as e:
print(f'Error calling model endpoint: {str(e)}')
raise
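After this change the method reduces to a thin wrapper around the shared helper. A sketch of the post-change state, reconstructed from the hunks above (surrounding lines may differ slightly):

    def _call_model_endpoint(self, messages, max_tokens=128):
        try:
            # Delegate to the shared helper and return the assistant message text.
            return query_endpoint(self.endpoint_name, messages, max_tokens)["content"]
        except Exception as e:
            print(f'Error calling model endpoint: {str(e)}')
            raise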
3 changes: 1 addition & 2 deletions dash-chatbot-app/app.py
@@ -2,7 +2,6 @@
import dash
import dash_bootstrap_components as dbc
from DatabricksChatbot import DatabricksChatbot

# Ensure environment variable is set correctly
serving_endpoint = os.getenv('SERVING_ENDPOINT')
assert serving_endpoint, 'SERVING_ENDPOINT must be set in app.yaml.'
@@ -21,4 +20,4 @@
], fluid=True)

if __name__ == '__main__':
app.run_server(debug=True)
app.run(debug=True)
@jerrylian-db (Contributor) · Apr 4, 2025:
Thanks for fixing this! The template is actually broken in production right now because Dash 3+ no longer supports the run_server method. 😅
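For reference, a minimal sketch of the fix (the app name and debug flag here are illustrative):

    import dash

    app = dash.Dash(__name__)

    if __name__ == '__main__':
        # Dash 3.x removed app.run_server(); app.run() is the replacement.
        app.run(debug=True)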

2 changes: 1 addition & 1 deletion dash-chatbot-app/app.yaml
@@ -5,4 +5,4 @@ command: [

env:
- name: "SERVING_ENDPOINT"
valueFrom: "serving-endpoint"
valueFrom: "serving_endpoint"
19 changes: 19 additions & 0 deletions dash-chatbot-app/model_serving_utils.py
@@ -0,0 +1,19 @@
from mlflow.deployments import get_deploy_client

def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
"""Calls a model serving endpoint."""
res = get_deploy_client('databricks').predict(
endpoint=endpoint_name,
inputs={'messages': messages, "max_tokens": max_tokens},
)
if "messages" in res:
return res["messages"]
elif "choices" in res:
return [res["choices"][0]["message"]]
raise Exception("This app can only run against:"
"1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
"2) Databricks agent serving endpoints that implement the conversational agent schema documented "
"in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")

def query_endpoint(endpoint_name, messages, max_tokens):
return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
Contributor:
Definitely doesn't have to be done here, but should we add a test to ensure that this file stays the same across the different chatbot templates, to reduce room for potential error? cc @jerrylian-db

Contributor Author:
Yeah, that's a good idea - it seems we don't have CI set up in this repo yet, but we should, and we should probably also add some tests for these apps.

@smurching (Contributor Author) · Apr 8, 2025:

Basic tests are probably sufficient (we can just follow best practices for testing a Dash, Streamlit, or Gradio app, etc.) and hopefully not too confusing to give to end users. (We could also leave the tests outside the template; I'm fine with that too, but IMO they seem actually useful.)

Contributor Author:
I can send a follow-up with that change after this one; will discuss offline with @aakrati and @jerrylian-db on what makes sense.

Contributor Author:
(filed #22 on top of the current PR, can iterate/follow up on this one after this initial PR lands)
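A minimal usage sketch of the new helper (the endpoint name and messages below are illustrative):

    from model_serving_utils import query_endpoint

    # Returns a single OpenAI-style message dict, e.g.
    # {"role": "assistant", "content": "..."}.
    reply = query_endpoint(
        endpoint_name="my-chat-endpoint",  # hypothetical endpoint name
        messages=[{"role": "user", "content": "Hello!"}],
        max_tokens=128,
    )
    print(reply["content"])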

8 changes: 4 additions & 4 deletions dash-chatbot-app/requirements.txt
@@ -1,4 +1,4 @@
dash
dash-bootstrap-components
databricks-sdk
python-dotenv
dash==3.0.2
dash-bootstrap-components==2.0.0
mlflow>=2.21.2
python-dotenv==1.1.0
27 changes: 15 additions & 12 deletions gradio-chatbot-app/app.py
@@ -1,38 +1,41 @@
import gradio as gr
import logging
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
import os
from model_serving_utils import query_endpoint

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the Databricks Workspace Client
workspace_client = WorkspaceClient()

# Ensure environment variable is set correctly
assert os.getenv('SERVING_ENDPOINT'), "SERVING_ENDPOINT must be set in app.yaml."

def query_llm(message, history):
"""
Query the LLM with the given message and chat history.
`message`: str - the latest user input.
`history`: list of [user_message, assistant_message] pairs - Gradio-style history.
"""
if not message.strip():
return "ERROR: The question should not be empty"

prompt = "Answer this question like a helpful assistant: "
messages = [ChatMessage(role=ChatMessageRole.USER, content=prompt + message)]
# Convert from Gradio-style history to OpenAI-style messages
message_history = []
for user_msg, assistant_msg in history:
message_history.append({"role": "user", "content": user_msg})
message_history.append({"role": "assistant", "content": assistant_msg})

# Add the latest user message
message_history.append({"role": "user", "content": message})

try:
logger.info(f"Sending request to model endpoint: {os.getenv('SERVING_ENDPOINT')}")
response = workspace_client.serving_endpoints.query(
name=os.getenv('SERVING_ENDPOINT'),
messages=messages,
response = query_endpoint(
endpoint_name=os.getenv('SERVING_ENDPOINT'),
messages=message_history,
max_tokens=400
)
logger.info("Received response from model endpoint")
return response.choices[0].message.content
return response["content"]
except Exception as e:
logger.error(f"Error querying model: {str(e)}", exc_info=True)
return f"Error: {str(e)}"
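The hunk ends before the UI wiring; a minimal sketch of how query_llm would typically be hooked up (the gr.ChatInterface call below is an assumption, not part of this diff):

    demo = gr.ChatInterface(
        fn=query_llm,                # (message, history) -> reply string, as defined above
        title="Databricks Chatbot",  # illustrative title
    )

    if __name__ == "__main__":
        demo.launch()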
2 changes: 1 addition & 1 deletion gradio-chatbot-app/app.yaml
@@ -5,4 +5,4 @@ command: [

env:
- name: "SERVING_ENDPOINT"
valueFrom: "serving-endpoint"
valueFrom: "serving_endpoint"
19 changes: 19 additions & 0 deletions gradio-chatbot-app/model_serving_utils.py
@@ -0,0 +1,19 @@
from mlflow.deployments import get_deploy_client

def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
"""Calls a model serving endpoint."""
res = get_deploy_client('databricks').predict(
endpoint=endpoint_name,
inputs={'messages': messages, "max_tokens": max_tokens},
)
if "messages" in res:
return res["messages"]
elif "choices" in res:
return [res["choices"][0]["message"]]
raise Exception("This app can only run against:"
"1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
"2) Databricks agent serving endpoints that implement the conversational agent schema documented "
"in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")

def query_endpoint(endpoint_name, messages, max_tokens):
return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
Contributor:
Would it be possible to document that we only take the last message from the list of ChatAgent messages?
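A sketch of what that note could look like (wording is illustrative; the streamlit copy of this file below adds a similar docstring):

    def query_endpoint(endpoint_name, messages, max_tokens):
        """
        Query a chat-completions or agent serving endpoint.

        If the endpoint is an agent endpoint that returns multiple
        messages, only the last message is returned.
        """
        return _query_endpoint(endpoint_name, messages, max_tokens)[-1]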

4 changes: 2 additions & 2 deletions gradio-chatbot-app/requirements.txt
@@ -1,2 +1,2 @@
gradio
databricks-sdk>=0.1.0
gradio==5.23.3
mlflow>=2.21.2
22 changes: 5 additions & 17 deletions shiny-chatbot-app/app.py
@@ -1,9 +1,7 @@
# Shiny for Python LLM Chat Example with Databricks
import os
from databricks.sdk import config
from openai import AsyncOpenAI
from shiny import App, ui, reactive

from model_serving_utils import query_endpoint
# Ensure environment variable is set correctly
assert os.getenv("SERVING_ENDPOINT"), "SERVING_ENDPOINT must be set in app.yaml."

@@ -20,16 +18,6 @@
)

def server(input, output, session):
# Application is using credentials via the `databricks.sdk`
cfg = config.Config()

# `openai` library can be configured to use Databricks model serving
# Databricks model endpoints are openai compatible
llm = AsyncOpenAI(
api_key='',
base_url=f"https://{cfg.hostname}/serving-endpoints",
default_headers=cfg.authenticate()
)

chat = ui.Chat(id="chat", messages=[])

@@ -43,12 +31,12 @@ async def _():
@chat.on_user_submit
async def _():
messages = chat.messages(format="openai")
response = await llm.chat.completions.create(
model=os.getenv("SERVING_ENDPOINT"),
response = query_endpoint(
endpoint_name=os.getenv("SERVING_ENDPOINT"),
messages=messages,
stream=True
max_tokens=400
)
await chat.append_message_stream(response)
await chat.append_message(response)

app = App(app_ui, server)

2 changes: 1 addition & 1 deletion shiny-chatbot-app/app.yaml
@@ -9,4 +9,4 @@ command: [

env:
- name: "SERVING_ENDPOINT"
valueFrom: "serving-endpoint"
valueFrom: "serving_endpoint"
20 changes: 20 additions & 0 deletions shiny-chatbot-app/model_serving_utils.py
@@ -0,0 +1,20 @@
from mlflow.deployments import get_deploy_client

def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
"""Calls a model serving endpoint."""
res = get_deploy_client('databricks').predict(
endpoint=endpoint_name,
inputs={'messages': messages, "max_tokens": max_tokens},
)
if "messages" in res:
return res["messages"]
elif "choices" in res:
return [res["choices"][0]["message"]]
raise Exception("This app can only run against:"
"1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
"2) Databricks agent serving endpoints that implement the conversational agent schema documented "
"in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")


def query_endpoint(endpoint_name, messages, max_tokens):
return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
6 changes: 3 additions & 3 deletions shiny-chatbot-app/requirements.txt
@@ -1,4 +1,4 @@
shiny==1.0.0
databricks-sdk
tokenizers
openai
mlflow>=2.21.2
tokenizers==0.21.1
openai==1.70.0
26 changes: 8 additions & 18 deletions streamlit-chatbot-app/app.py
@@ -1,16 +1,12 @@
import logging
import os
import streamlit as st
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
from model_serving_utils import query_endpoint

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the Databricks Workspace Client
w = WorkspaceClient()

# Ensure environment variable is set correctly
assert os.getenv('SERVING_ENDPOINT'), "SERVING_ENDPOINT must be set in app.yaml."

@@ -49,22 +45,16 @@ def get_user_info():
with st.chat_message("user"):
st.markdown(prompt)

messages = [ChatMessage(role=ChatMessageRole.SYSTEM, content="You are a helpful assistant."),
ChatMessage(role=ChatMessageRole.USER, content=prompt)]

# Display assistant response in chat message container
with st.chat_message("assistant"):
# Query the Databricks serving endpoint
try:
response = w.serving_endpoints.query(
name=os.getenv("SERVING_ENDPOINT"),
messages=messages,
max_tokens=400,
)
assistant_response = response.choices[0].message.content
st.markdown(assistant_response)
except Exception as e:
st.error(f"Error querying model: {e}")
assistant_response = query_endpoint(
endpoint_name=os.getenv("SERVING_ENDPOINT"),
messages=st.session_state.messages,
max_tokens=400,
)["content"]
st.markdown(assistant_response)


# Add assistant response to chat history
st.session_state.messages.append({"role": "assistant", "content": assistant_response})
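The collapsed portion of this hunk presumably initializes and renders the chat history; a typical Streamlit pattern (an assumption, not shown in this diff):

    # Initialize chat history once per session.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Re-render prior messages on each rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])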
2 changes: 1 addition & 1 deletion streamlit-chatbot-app/app.yaml
@@ -8,4 +8,4 @@ env:
- name: STREAMLIT_BROWSER_GATHER_USAGE_STATS
value: "false"
- name: "SERVING_ENDPOINT"
valueFrom: "serving-endpoint"
valueFrom: "serving_endpoint"
24 changes: 24 additions & 0 deletions streamlit-chatbot-app/model_serving_utils.py
@@ -0,0 +1,24 @@
from mlflow.deployments import get_deploy_client

def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
"""Calls a model serving endpoint."""
res = get_deploy_client('databricks').predict(
endpoint=endpoint_name,
inputs={'messages': messages, "max_tokens": max_tokens},
)
if "messages" in res:
return res["messages"]
elif "choices" in res:
return [res["choices"][0]["message"]]
raise Exception("This app can only run against:"
"1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
"2) Databricks agent serving endpoints that implement the conversational agent schema documented "
"in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")

def query_endpoint(endpoint_name, messages, max_tokens):
"""
Query a chat-completions or agent serving endpoint
If querying an agent serving endpoint that returns multiple messages, this method
returns the last message
."""
return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
3 changes: 2 additions & 1 deletion streamlit-chatbot-app/requirements.txt
@@ -1 +1,2 @@
openai
mlflow>=2.21.2
streamlit==1.44.1