diff --git a/.gitignore b/.gitignore
index 4237253..8ef54fe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -138,3 +138,4 @@
 ml-models.iml
 ml-models.ipr
 .DS_Store
+.databricks
diff --git a/dash-chatbot-app/DatabricksChatbot.py b/dash-chatbot-app/DatabricksChatbot.py
index 56121ec..5162961 100644
--- a/dash-chatbot-app/DatabricksChatbot.py
+++ b/dash-chatbot-app/DatabricksChatbot.py
@@ -1,24 +1,13 @@
 import dash
 from dash import html, Input, Output, State, dcc
 import dash_bootstrap_components as dbc
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
-
+from model_serving_utils import query_endpoint
 
 class DatabricksChatbot:
     def __init__(self, app, endpoint_name, height='600px'):
         self.app = app
         self.endpoint_name = endpoint_name
         self.height = height
-
-        try:
-            print('Initializing WorkspaceClient...')
-            self.w = WorkspaceClient()
-            print('WorkspaceClient initialized successfully')
-        except Exception as e:
-            print(f'Error initializing WorkspaceClient: {str(e)}')
-            self.w = None
-
         self.layout = self._create_layout()
         self._create_callbacks()
         self._add_custom_css()
@@ -111,25 +100,9 @@ def clear_chat(n_clicks):
             return dash.no_update, dash.no_update
 
     def _call_model_endpoint(self, messages, max_tokens=128):
-        if self.w is None:
-            raise Exception('WorkspaceClient is not initialized')
-
-        chat_messages = [
-            ChatMessage(
-                content=message['content'],
-                role=ChatMessageRole[message['role'].upper()]
-            ) for message in messages
-        ]
         try:
             print('Calling model endpoint...')
-            response = self.w.serving_endpoints.query(
-                name=self.endpoint_name,
-                messages=chat_messages,
-                max_tokens=max_tokens
-            )
-            message = response.choices[0].message.content
-            print('Model endpoint called successfully')
-            return message
+            return query_endpoint(self.endpoint_name, messages, max_tokens)["content"]
         except Exception as e:
             print(f'Error calling model endpoint: {str(e)}')
             raise
diff --git a/dash-chatbot-app/app.py b/dash-chatbot-app/app.py
index d113c23..ca609e5 100644
--- a/dash-chatbot-app/app.py
+++ b/dash-chatbot-app/app.py
@@ -2,7 +2,6 @@
 import dash
 import dash_bootstrap_components as dbc
 from DatabricksChatbot import DatabricksChatbot
-
 # Ensure environment variable is set correctly
 serving_endpoint = os.getenv('SERVING_ENDPOINT')
 assert serving_endpoint, 'SERVING_ENDPOINT must be set in app.yaml.'
@@ -21,4 +20,4 @@
 ], fluid=True)
 
 if __name__ == '__main__':
-    app.run_server(debug=True)
+    app.run(debug=True)
diff --git a/dash-chatbot-app/app.yaml b/dash-chatbot-app/app.yaml
index 4ebce16..3f99b50 100644
--- a/dash-chatbot-app/app.yaml
+++ b/dash-chatbot-app/app.yaml
@@ -5,4 +5,4 @@ command: [
 
 env:
   - name: "SERVING_ENDPOINT"
-    valueFrom: "serving-endpoint"
+    valueFrom: "serving_endpoint"
diff --git a/dash-chatbot-app/model_serving_utils.py b/dash-chatbot-app/model_serving_utils.py
new file mode 100644
index 0000000..964c938
--- /dev/null
+++ b/dash-chatbot-app/model_serving_utils.py
@@ -0,0 +1,19 @@
+from mlflow.deployments import get_deploy_client
+
+def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
+    """Calls a model serving endpoint."""
+    res = get_deploy_client('databricks').predict(
+        endpoint=endpoint_name,
+        inputs={'messages': messages, "max_tokens": max_tokens},
+    )
+    if "messages" in res:
+        return res["messages"]
+    elif "choices" in res:
+        return [res["choices"][0]["message"]]
+    raise Exception("This app can only run against:"
+                    "1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
+                    "2) Databricks agent serving endpoints that implement the conversational agent schema documented "
+                    "in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")
+
+def query_endpoint(endpoint_name, messages, max_tokens):
+    return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
diff --git a/dash-chatbot-app/requirements.txt b/dash-chatbot-app/requirements.txt
index b4ca85d..08008da 100644
--- a/dash-chatbot-app/requirements.txt
+++ b/dash-chatbot-app/requirements.txt
@@ -1,4 +1,4 @@
-dash
-dash-bootstrap-components
-databricks-sdk
-python-dotenv
+dash==3.0.2
+dash-bootstrap-components==2.0.0
+mlflow>=2.21.2
+python-dotenv==1.1.0
diff --git a/gradio-chatbot-app/app.py b/gradio-chatbot-app/app.py
index ad26f03..5d39071 100644
--- a/gradio-chatbot-app/app.py
+++ b/gradio-chatbot-app/app.py
@@ -1,38 +1,41 @@
 import gradio as gr
 import logging
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
 import os
+from model_serving_utils import query_endpoint
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Initialize the Databricks Workspace Client
-workspace_client = WorkspaceClient()
-
 # Ensure environment variable is set correctly
 assert os.getenv('SERVING_ENDPOINT'), "SERVING_ENDPOINT must be set in app.yaml."
 
 def query_llm(message, history):
     """
     Query the LLM with the given message and chat history.
+    `message`: str - the latest user input.
+    `history`: list of dicts - OpenAI-style messages.
""" if not message.strip(): return "ERROR: The question should not be empty" - prompt = "Answer this question like a helpful assistant: " - messages = [ChatMessage(role=ChatMessageRole.USER, content=prompt + message)] + # Convert from Gradio-style history to OpenAI-style messages + message_history = [] + for user_msg, assistant_msg in history: + message_history.append({"role": "user", "content": user_msg}) + message_history.append({"role": "assistant", "content": assistant_msg}) + + # Add the latest user message + message_history.append({"role": "user", "content": message}) try: logger.info(f"Sending request to model endpoint: {os.getenv('SERVING_ENDPOINT')}") - response = workspace_client.serving_endpoints.query( - name=os.getenv('SERVING_ENDPOINT'), - messages=messages, + response = query_endpoint( + endpoint_name=os.getenv('SERVING_ENDPOINT'), + messages=message_history, max_tokens=400 ) - logger.info("Received response from model endpoint") - return response.choices[0].message.content + return response["content"] except Exception as e: logger.error(f"Error querying model: {str(e)}", exc_info=True) return f"Error: {str(e)}" diff --git a/gradio-chatbot-app/app.yaml b/gradio-chatbot-app/app.yaml index 05c6146..3645c2d 100644 --- a/gradio-chatbot-app/app.yaml +++ b/gradio-chatbot-app/app.yaml @@ -5,4 +5,4 @@ command: [ env: - name: "SERVING_ENDPOINT" - valueFrom: "serving-endpoint" + valueFrom: "serving_endpoint" diff --git a/gradio-chatbot-app/model_serving_utils.py b/gradio-chatbot-app/model_serving_utils.py new file mode 100644 index 0000000..964c938 --- /dev/null +++ b/gradio-chatbot-app/model_serving_utils.py @@ -0,0 +1,19 @@ +from mlflow.deployments import get_deploy_client + +def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]: + """Calls a model serving endpoint.""" + res = get_deploy_client('databricks').predict( + endpoint=endpoint_name, + inputs={'messages': messages, "max_tokens": max_tokens}, + ) + if "messages" in res: + return res["messages"] + elif "choices" in res: + return [res["choices"][0]["message"]] + raise Exception("This app can only run against:" + "1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)" + "2) Databricks agent serving endpoints that implement the conversational agent schema documented " + "in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent") + +def query_endpoint(endpoint_name, messages, max_tokens): + return _query_endpoint(endpoint_name, messages, max_tokens)[-1] diff --git a/gradio-chatbot-app/requirements.txt b/gradio-chatbot-app/requirements.txt index e8477e8..352e530 100644 --- a/gradio-chatbot-app/requirements.txt +++ b/gradio-chatbot-app/requirements.txt @@ -1,2 +1,2 @@ -gradio -databricks-sdk>=0.1.0 \ No newline at end of file +gradio==5.23.3 +mlflow>=2.21.2 diff --git a/shiny-chatbot-app/app.py b/shiny-chatbot-app/app.py index f0fe419..c7c31b7 100644 --- a/shiny-chatbot-app/app.py +++ b/shiny-chatbot-app/app.py @@ -1,9 +1,7 @@ # Shiny for Python LLM Chat Example with Databricks import os -from databricks.sdk import config -from openai import AsyncOpenAI from shiny import App, ui, reactive - +from model_serving_utils import query_endpoint # Ensure environment variable is set correctly assert os.getenv("SERVING_ENDPOINT"), "SERVING_ENDPOINT must be set in app.yaml." 
@@ -20,16 +18,6 @@
 )
 
 
 def server(input, output, session):
-    # Application is using credentials via the `databricks.sdk`
-    cfg = config.Config()
-
-    # `openai` library can be configured to use Databricks model serving
-    # Databricks model endpoints are openai compatible
-    llm = AsyncOpenAI(
-        api_key='',
-        base_url=f"https://{cfg.hostname}/serving-endpoints",
-        default_headers=cfg.authenticate()
-    )
 
     chat = ui.Chat(id="chat", messages=[])
@@ -43,12 +31,12 @@ async def _():
 
     @chat.on_user_submit
     async def _():
         messages = chat.messages(format="openai")
-        response = await llm.chat.completions.create(
-            model=os.getenv("SERVING_ENDPOINT"),
+        response = query_endpoint(
+            endpoint_name=os.getenv("SERVING_ENDPOINT"),
             messages=messages,
-            stream=True
+            max_tokens=400
         )
-        await chat.append_message_stream(response)
+        await chat.append_message(response)
 
 app = App(app_ui, server)
diff --git a/shiny-chatbot-app/app.yaml b/shiny-chatbot-app/app.yaml
index ae9cc1a..b2cb59e 100644
--- a/shiny-chatbot-app/app.yaml
+++ b/shiny-chatbot-app/app.yaml
@@ -9,4 +9,4 @@ command: [
 
 env:
   - name: "SERVING_ENDPOINT"
-    valueFrom: "serving-endpoint"
+    valueFrom: "serving_endpoint"
diff --git a/shiny-chatbot-app/model_serving_utils.py b/shiny-chatbot-app/model_serving_utils.py
new file mode 100644
index 0000000..3dfa56e
--- /dev/null
+++ b/shiny-chatbot-app/model_serving_utils.py
@@ -0,0 +1,20 @@
+from mlflow.deployments import get_deploy_client
+
+def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
+    """Calls a model serving endpoint."""
+    res = get_deploy_client('databricks').predict(
+        endpoint=endpoint_name,
+        inputs={'messages': messages, "max_tokens": max_tokens},
+    )
+    if "messages" in res:
+        return res["messages"]
+    elif "choices" in res:
+        return [res["choices"][0]["message"]]
+    raise Exception("This app can only run against:"
+                    "1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
+                    "2) Databricks agent serving endpoints that implement the conversational agent schema documented "
+                    "in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")
+
+
+def query_endpoint(endpoint_name, messages, max_tokens):
+    return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
diff --git a/shiny-chatbot-app/requirements.txt b/shiny-chatbot-app/requirements.txt
index 5d9cb62..28854dc 100644
--- a/shiny-chatbot-app/requirements.txt
+++ b/shiny-chatbot-app/requirements.txt
@@ -1,4 +1,4 @@
 shiny==1.0.0
-databricks-sdk
-tokenizers
-openai
\ No newline at end of file
+mlflow>=2.21.2
+tokenizers==0.21.1
+openai==1.70.0
diff --git a/streamlit-chatbot-app/app.py b/streamlit-chatbot-app/app.py
index b390b4f..1632b39 100644
--- a/streamlit-chatbot-app/app.py
+++ b/streamlit-chatbot-app/app.py
@@ -1,16 +1,12 @@
 import logging
 import os
 import streamlit as st
-from databricks.sdk import WorkspaceClient
-from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
+from model_serving_utils import query_endpoint
 
 # Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-# Initialize the Databricks Workspace Client
-w = WorkspaceClient()
-
 # Ensure environment variable is set correctly
 assert os.getenv('SERVING_ENDPOINT'), "SERVING_ENDPOINT must be set in app.yaml."
@@ -49,22 +45,16 @@ def get_user_info():
     with st.chat_message("user"):
         st.markdown(prompt)
 
-    messages = [ChatMessage(role=ChatMessageRole.SYSTEM, content="You are a helpful assistant."),
-                ChatMessage(role=ChatMessageRole.USER, content=prompt)]
-
     # Display assistant response in chat message container
     with st.chat_message("assistant"):
         # Query the Databricks serving endpoint
-        try:
-            response = w.serving_endpoints.query(
-                name=os.getenv("SERVING_ENDPOINT"),
-                messages=messages,
-                max_tokens=400,
-            )
-            assistant_response = response.choices[0].message.content
-            st.markdown(assistant_response)
-        except Exception as e:
-            st.error(f"Error querying model: {e}")
+        assistant_response = query_endpoint(
+            endpoint_name=os.getenv("SERVING_ENDPOINT"),
+            messages=st.session_state.messages,
+            max_tokens=400,
+        )["content"]
+        st.markdown(assistant_response)
+
     # Add assistant response to chat history
     st.session_state.messages.append({"role": "assistant", "content": assistant_response})
diff --git a/streamlit-chatbot-app/app.yaml b/streamlit-chatbot-app/app.yaml
index d81e38a..8f0e050 100644
--- a/streamlit-chatbot-app/app.yaml
+++ b/streamlit-chatbot-app/app.yaml
@@ -8,4 +8,4 @@ env:
   - name: STREAMLIT_BROWSER_GATHER_USAGE_STATS
     value: "false"
   - name: "SERVING_ENDPOINT"
-    valueFrom: "serving-endpoint"
+    valueFrom: "serving_endpoint"
diff --git a/streamlit-chatbot-app/model_serving_utils.py b/streamlit-chatbot-app/model_serving_utils.py
new file mode 100644
index 0000000..49703fe
--- /dev/null
+++ b/streamlit-chatbot-app/model_serving_utils.py
@@ -0,0 +1,24 @@
+from mlflow.deployments import get_deploy_client
+
+def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
+    """Calls a model serving endpoint."""
+    res = get_deploy_client('databricks').predict(
+        endpoint=endpoint_name,
+        inputs={'messages': messages, "max_tokens": max_tokens},
+    )
+    if "messages" in res:
+        return res["messages"]
+    elif "choices" in res:
+        return [res["choices"][0]["message"]]
+    raise Exception("This app can only run against:"
+                    "1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
+                    "2) Databricks agent serving endpoints that implement the conversational agent schema documented "
+                    "in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")
+
+def query_endpoint(endpoint_name, messages, max_tokens):
+    """
+    Query a chat-completions or agent serving endpoint
+    If querying an agent serving endpoint that returns multiple messages, this method
+    returns the last message
+    ."""
+    return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
diff --git a/streamlit-chatbot-app/requirements.txt b/streamlit-chatbot-app/requirements.txt
index f0dd0ae..6dc0a94 100644
--- a/streamlit-chatbot-app/requirements.txt
+++ b/streamlit-chatbot-app/requirements.txt
@@ -1 +1,2 @@
-openai
\ No newline at end of file
+mlflow>=2.21.2
+streamlit==1.44.1
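
For reference, a minimal sketch of calling the new model_serving_utils.query_endpoint helper added in this diff, outside of any UI framework. This assumes Databricks credentials are available in the environment (as they are for a deployed Databricks App) and that SERVING_ENDPOINT names a chat-completions or agent serving endpoint you can access; the endpoint name and prompt below are placeholders.

    import os
    from model_serving_utils import query_endpoint

    # query_endpoint returns the last message dict from the endpoint,
    # e.g. {"role": "assistant", "content": "..."}
    reply = query_endpoint(
        endpoint_name=os.getenv("SERVING_ENDPOINT"),
        messages=[{"role": "user", "content": "What is MLflow?"}],
        max_tokens=128,
    )
    print(reply["content"])

Because every app now funnels requests through this helper (via mlflow.deployments.get_deploy_client('databricks')), the same code path serves foundation-model, external-model, and agent endpoints, which is why the databricks-sdk and openai clients could be dropped from most of the apps' requirements.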