3 changes: 3 additions & 0 deletions e2e-chatbot-app/app.yaml
@@ -7,5 +7,8 @@ command: [
env:
  - name: STREAMLIT_BROWSER_GATHER_USAGE_STATS
    value: "false"
  - name: "OBO_ENABLED"
    value: "false"
  - name: "SERVING_ENDPOINT"
    valueFrom: "serving-endpoint"
    # value: "<for OBO, comment out the valueFrom line above, uncomment this line, and enter your serving endpoint name>"
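
For reference, a minimal sketch of what this env block could look like with OBO enabled, following the comments above; the endpoint name is a placeholder, not a value from this change:

env:
  - name: STREAMLIT_BROWSER_GATHER_USAGE_STATS
    value: "false"
  - name: "OBO_ENABLED"
    value: "true"
  - name: "SERVING_ENDPOINT"
    # valueFrom: "serving-endpoint"   # commented out when OBO is enabled
    value: "<your-serving-endpoint-name>"
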
37 changes: 29 additions & 8 deletions e2e-chatbot-app/model_serving_utils.py
@@ -2,6 +2,8 @@
from databricks.sdk import WorkspaceClient
import json
import uuid
import streamlit as st
import os

import logging

@@ -11,10 +13,29 @@
    level=logging.DEBUG
)

# For on-behalf-of-user (OBO) auth, set OBO_ENABLED to "true" in app.yaml.
# Env var values are strings, so parse the flag explicitly rather than relying
# on truthiness (otherwise the string "false" would still enable OBO).
is_OBO = os.getenv("OBO_ENABLED", "false").lower() == "true"


def _get_workspace_client():
    if is_OBO:
        # Run requests with the signed-in user's permissions via the forwarded token.
        user_access_token = st.context.headers.get('X-Forwarded-Access-Token')
        os.environ["DATABRICKS_TOKEN"] = user_access_token
        return WorkspaceClient(token=user_access_token, auth_type="pat")
    else:
        return WorkspaceClient()


def _get_mlflow_client():
    if is_OBO:
        # The MLflow deployments client reads credentials from the environment.
        user_access_token = st.context.headers.get('X-Forwarded-Access-Token')
        os.environ["DATABRICKS_AUTH_TYPE"] = "pat"
        os.environ["DATABRICKS_TOKEN"] = user_access_token
    return get_deploy_client("databricks")


def _get_endpoint_task_type(endpoint_name: str) -> str:
"""Get the task type of a serving endpoint."""
try:
w = WorkspaceClient()
w = _get_workpace_client()
ep = w.serving_endpoints.get(endpoint_name)
return ep.task if ep.task else "chat/completions"
except Exception:
@@ -75,7 +96,7 @@ def query_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], re

def _query_chat_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], return_traces: bool):
"""Invoke an endpoint that implements either chat completions or ChatAgent and stream the response"""
client = get_deploy_client("databricks")
client = _get_mlflow_client()

# Prepare input payload
inputs = {
@@ -94,7 +115,7 @@ def _query_chat_endpoint_stream(endpoint_name: str, messages: list[dict[str, str

def _query_responses_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], return_traces: bool):
"""Stream responses from agent/v1/responses endpoints using MLflow deployments client."""
client = get_deploy_client("databricks")
client = _get_mlflow_client()

input_messages = _convert_to_responses_format(messages)

@@ -128,8 +149,8 @@ def _query_chat_endpoint(endpoint_name, messages, return_traces):
    inputs = {'messages': messages}
    if return_traces:
        inputs['databricks_options'] = {'return_trace': True}
    res = get_deploy_client('databricks').predict(

    res = _get_mlflow_client().predict(
        endpoint=endpoint_name,
        inputs=inputs,
    )
@@ -142,7 +163,7 @@ def _query_chat_endpoint(endpoint_name, messages, return_traces):

def _query_responses_endpoint(endpoint_name, messages, return_traces):
"""Query agent/v1/responses endpoints using MLflow deployments client."""
client = get_deploy_client("databricks")
client = _get_mlflow_client()

input_messages = _convert_to_responses_format(messages)

@@ -238,7 +259,7 @@ def submit_feedback(endpoint, request_id, rating):
            }
        ]
    }
    w = WorkspaceClient()
    w = _get_workspace_client()
    return w.api_client.do(
        method='POST',
        path=f"/serving-endpoints/{endpoint}/served-models/feedback/invocations",
@@ -247,6 +268,6 @@ def submit_feedback(endpoint, request_id, rating):


def endpoint_supports_feedback(endpoint_name):
    w = WorkspaceClient()
    w = _get_workspace_client()
    endpoint = w.serving_endpoints.get(endpoint_name)
    return "feedback" in [entity.name for entity in endpoint.config.served_entities]
2 changes: 1 addition & 1 deletion e2e-chatbot-app/requirements.txt
@@ -1,2 +1,2 @@
mlflow>=2.21.2
mlflow>=2.22.1
streamlit==1.44.1
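
Taken together, these changes route every Databricks and MLflow client through the OBO flag. As a rough usage sketch (not part of this change), assuming the helper names as typo-fixed above and the same environment variables; the fallback endpoint name is a placeholder:

import os

# With OBO enabled, this must run inside a Streamlit request so the forwarded
# user token header is available; otherwise the app's own credentials are used.
from model_serving_utils import _get_mlflow_client, _get_workspace_client

endpoint = os.getenv("SERVING_ENDPOINT", "<your-serving-endpoint-name>")  # placeholder fallback

# Workspace client: acts as the signed-in user when OBO_ENABLED is "true",
# otherwise as the app's service principal.
w = _get_workspace_client()
print(w.serving_endpoints.get(endpoint).task)  # e.g. "chat/completions"

# MLflow deployments client: the same flag decides whose credentials are used.
client = _get_mlflow_client()
resp = client.predict(
    endpoint=endpoint,
    inputs={"messages": [{"role": "user", "content": "Hello"}]},
)
print(resp)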