diff --git a/e2e-chatbot-app/app.yaml b/e2e-chatbot-app/app.yaml
index d81e38a..0b309da 100644
--- a/e2e-chatbot-app/app.yaml
+++ b/e2e-chatbot-app/app.yaml
@@ -7,5 +7,8 @@ command: [
 env:
   - name: STREAMLIT_BROWSER_GATHER_USAGE_STATS
     value: "false"
+  - name: "OBO_ENABLED"
+    value: "false"
   - name: "SERVING_ENDPOINT"
     valueFrom: "serving-endpoint"
+    # value: ""
diff --git a/e2e-chatbot-app/model_serving_utils.py b/e2e-chatbot-app/model_serving_utils.py
index bdb6eb6..b2189a5 100644
--- a/e2e-chatbot-app/model_serving_utils.py
+++ b/e2e-chatbot-app/model_serving_utils.py
@@ -2,6 +2,8 @@
 from databricks.sdk import WorkspaceClient
 import json
 import uuid
+import streamlit as st
+import os
 import logging
 
 
@@ -11,10 +13,29 @@
     level=logging.DEBUG
 )
 
+# For on-behalf-of-user (OBO) auth, keep this flag in sync with app.yaml.
+is_OBO = os.getenv("OBO_ENABLED", "false").lower() == "true"
+
+def _get_workspace_client():
+    if is_OBO:
+        # Databricks Apps forwards the signed-in user's token in this header.
+        user_access_token = st.context.headers.get('X-Forwarded-Access-Token')
+        os.environ["DATABRICKS_TOKEN"] = user_access_token
+        return WorkspaceClient(token=user_access_token, auth_type="pat")
+    return WorkspaceClient()
+
+def _get_mlflow_client():
+    if is_OBO:
+        user_access_token = st.context.headers.get('X-Forwarded-Access-Token')
+        os.environ["DATABRICKS_AUTH_TYPE"] = "pat"
+        os.environ["DATABRICKS_TOKEN"] = user_access_token
+    # The MLflow deployments client picks up credentials from the environment.
+    return get_deploy_client("databricks")
+
 def _get_endpoint_task_type(endpoint_name: str) -> str:
     """Get the task type of a serving endpoint."""
     try:
-        w = WorkspaceClient()
+        w = _get_workspace_client()
         ep = w.serving_endpoints.get(endpoint_name)
         return ep.task if ep.task else "chat/completions"
     except Exception:
@@ -75,7 +96,7 @@ def query_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], re
 
 def _query_chat_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], return_traces: bool):
     """Invoke an endpoint that implements either chat completions or ChatAgent and stream the response"""
-    client = get_deploy_client("databricks")
+    client = _get_mlflow_client()
 
     # Prepare input payload
     inputs = {
@@ -94,7 +115,7 @@
 
 def _query_responses_endpoint_stream(endpoint_name: str, messages: list[dict[str, str]], return_traces: bool):
     """Stream responses from agent/v1/responses endpoints using MLflow deployments client."""
-    client = get_deploy_client("databricks")
+    client = _get_mlflow_client()
 
     input_messages = _convert_to_responses_format(messages)
 
@@ -128,8 +149,8 @@ def _query_chat_endpoint(endpoint_name, messages, return_traces):
     inputs = {'messages': messages}
     if return_traces:
         inputs['databricks_options'] = {'return_trace': True}
-
-    res = get_deploy_client('databricks').predict(
+
+    res = _get_mlflow_client().predict(
         endpoint=endpoint_name,
         inputs=inputs,
     )
@@ -142,7 +163,7 @@
 
 def _query_responses_endpoint(endpoint_name, messages, return_traces):
     """Query agent/v1/responses endpoints using MLflow deployments client."""
-    client = get_deploy_client("databricks")
+    client = _get_mlflow_client()
 
     input_messages = _convert_to_responses_format(messages)
 
@@ -238,7 +259,7 @@ def submit_feedback(endpoint, request_id, rating):
             }
         ]
     }
-    w = WorkspaceClient()
+    w = _get_workspace_client()
     return w.api_client.do(
         method='POST',
         path=f"/serving-endpoints/{endpoint}/served-models/feedback/invocations",
@@ -247,6 +268,6 @@
 
 
 def endpoint_supports_feedback(endpoint_name):
-    w = WorkspaceClient()
+    w = _get_workspace_client()
     endpoint = w.serving_endpoints.get(endpoint_name)
     return "feedback" in [entity.name for entity in endpoint.config.served_entities]
diff --git a/e2e-chatbot-app/requirements.txt b/e2e-chatbot-app/requirements.txt
index 6dc0a94..eace1f4 100644
--- a/e2e-chatbot-app/requirements.txt
+++ b/e2e-chatbot-app/requirements.txt
@@ -1,2 +1,2 @@
-mlflow>=2.21.2
+mlflow>=2.22.1
 streamlit==1.44.1
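
Note: OBO_ENABLED is read as a string, so app.yaml must set the literal "true" to turn on-behalf-of-user auth on. Below is a minimal, hypothetical sketch (not part of the patch) for sanity-checking the helper wiring locally; it assumes the patched model_serving_utils.py is importable and a Databricks profile is configured, since outside a deployed app no X-Forwarded-Access-Token header exists and the helpers fall back to the app's own credentials.

# Hypothetical smoke test for the auth helpers (not part of the patch).
# Run locally with a configured Databricks profile; OBO is forced off so
# the non-OBO code path is exercised.
import os

os.environ.pop("OBO_ENABLED", None)  # must happen before the import below

from model_serving_utils import _get_workspace_client

w = _get_workspace_client()
# current_user.me() reports which principal the SDK authenticated as; in a
# deployed app with OBO enabled this would be the signed-in user instead.
print(w.current_user.me().user_name)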