Update chatbot templates to support agent serving endpoints #21
Changes from all commits: 3b3ee13, 36cb2b0, 1f20160, 699cbc4, 20fc12c, 3bff47c, 8da25a8, bade341
`.gitignore`:

```diff
@@ -138,3 +138,4 @@ ml-models.iml
 ml-models.ipr
 
 .DS_Store
+.databricks
```
Dash chatbot template `app.yaml`:

```diff
@@ -5,4 +5,4 @@ command: [
 
 env:
   - name: "SERVING_ENDPOINT"
-    valueFrom: "serving-endpoint"
+    valueFrom: "serving_endpoint"
```
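For context, this `valueFrom` binding exposes the app's serving-endpoint resource to the process as the `SERVING_ENDPOINT` environment variable. A minimal sketch of how template code would read it (the error handling here is illustrative, not part of the diff):

```python
import os

# SERVING_ENDPOINT is injected by the Databricks Apps runtime based on
# the env binding in app.yaml above.
SERVING_ENDPOINT = os.getenv("SERVING_ENDPOINT")
if not SERVING_ENDPOINT:
    raise RuntimeError("SERVING_ENDPOINT is not set; check the app.yaml env binding.")
```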
Dash chatbot template, new serving-utils helper module:

```diff
@@ -0,0 +1,19 @@
+from mlflow.deployments import get_deploy_client
+
+def _query_endpoint(endpoint_name: str, messages: list[dict[str, str]], max_tokens) -> list[dict[str, str]]:
+    """Calls a model serving endpoint."""
+    res = get_deploy_client('databricks').predict(
+        endpoint=endpoint_name,
+        inputs={'messages': messages, "max_tokens": max_tokens},
+    )
+    if "messages" in res:
+        return res["messages"]
+    elif "choices" in res:
+        return [res["choices"][0]["message"]]
+    raise Exception("This app can only run against:"
+                    "1) Databricks foundation model or external model endpoints with the chat task type (described in https://docs.databricks.com/aws/en/machine-learning/model-serving/score-foundation-models#chat-completion-model-query)"
+                    "2) Databricks agent serving endpoints that implement the conversational agent schema documented "
+                    "in https://docs.databricks.com/aws/en/generative-ai/agent-framework/author-agent")
+
+def query_endpoint(endpoint_name, messages, max_tokens):
+    return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
```
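To make the new helper's contract concrete, here is a hedged usage sketch. The prompt and token limit are made up, and the module filename `model_serving_utils` is an assumption (the scraped diff does not preserve file paths); `query_endpoint` returns a single message dict with `role` and `content` keys for both chat-completions and agent endpoints:

```python
import os

# Module filename assumed; the diff above does not show the file path.
from model_serving_utils import query_endpoint

# Illustrative call; assumes SERVING_ENDPOINT is bound via app.yaml.
reply = query_endpoint(
    endpoint_name=os.environ["SERVING_ENDPOINT"],
    messages=[{"role": "user", "content": "What is MLflow?"}],
    max_tokens=256,
)
print(reply["content"])
```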
Contributor: This definitely doesn't have to be done here, but should we add a test to ensure that this file is identical across the different chatbot templates, to reduce the room for error? cc @jerrylian-db

Author: Yeah, that's a good idea. It seems we don't have CI set up in this repo yet, but we should, and we should probably have some tests for these apps.

Author: Basic tests are probably sufficient (we can just follow best practices for testing a Dash, Streamlit, or Gradio app) and hopefully not too confusing to give to end users. (We could also leave the tests outside the template; I'm fine with that too, but IMO they seem genuinely useful.) A sketch of such a test follows below.

Author: I can send a follow-up with that change after this one; I'll discuss offline with @aakrati and @jerrylian-db on what makes sense.

Author: (Filed #22 on top of the current PR; we can iterate/follow up on that after this initial PR lands.)
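A hedged sketch of the cross-template consistency test discussed above. The template directory names and the helper filename are assumptions, not confirmed repo paths:

```python
from pathlib import Path

# Hypothetical paths: the actual template directory names are not shown in this diff.
TEMPLATE_DIRS = [
    "dash-chatbot-app",
    "gradio-chatbot-app",
    "shiny-chatbot-app",
    "streamlit-chatbot-app",
]


def test_serving_utils_identical_across_templates():
    contents = [
        (Path(d) / "model_serving_utils.py").read_text() for d in TEMPLATE_DIRS
    ]
    # Every copy must match the first one byte-for-byte.
    assert all(c == contents[0] for c in contents[1:])
```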
Dash chatbot template `requirements.txt`:

```diff
@@ -1,4 +1,4 @@
-dash
-dash-bootstrap-components
-databricks-sdk
-python-dotenv
+dash==3.0.2
+dash-bootstrap-components==2.0.0
+mlflow>=2.21.2
+python-dotenv==1.1.0
```
Gradio chatbot template `app.yaml`:

```diff
@@ -5,4 +5,4 @@ command: [
 
 env:
   - name: "SERVING_ENDPOINT"
-    valueFrom: "serving-endpoint"
+    valueFrom: "serving_endpoint"
```
Gradio chatbot template, new serving-utils helper module (`@@ -0,0 +1,19 @@`): identical to the helper added to the Dash template above.
Contributor: Would it be possible to document that we will only take the last message from the list of ChatAgent messages?
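To illustrate the behavior the reviewer asks to document: an agent endpoint may return a list of messages (for example, intermediate tool-call turns), and the helper keeps only the final one. A sketch with made-up illustrative data:

```python
# Illustrative agent-style response (made-up data, not real endpoint output).
res = {
    "messages": [
        {"role": "assistant", "content": "Let me look that up..."},
        {"role": "assistant", "content": "MLflow is an open source MLOps platform."},
    ]
}

# _query_endpoint returns res["messages"]; query_endpoint then takes [-1],
# so callers only ever see the final message.
final = res["messages"][-1]
print(final["content"])  # "MLflow is an open source MLOps platform."
```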
Gradio chatbot template `requirements.txt`:

```diff
@@ -1,2 +1,2 @@
-gradio
-databricks-sdk>=0.1.0
+gradio==5.23.3
+mlflow>=2.21.2
```
Shiny chatbot template `app.yaml`:

```diff
@@ -9,4 +9,4 @@ command: [
 
 env:
   - name: "SERVING_ENDPOINT"
-    valueFrom: "serving-endpoint"
+    valueFrom: "serving_endpoint"
```
Shiny chatbot template, new serving-utils helper module (`@@ -0,0 +1,20 @@`): identical to the Dash template's helper above, apart from an extra blank line (20 lines instead of 19).
Shiny chatbot template `requirements.txt`:

```diff
@@ -1,4 +1,4 @@
 shiny==1.0.0
-databricks-sdk
-tokenizers
-openai
+mlflow>=2.21.2
+tokenizers==0.21.1
+openai==1.70.0
```
Streamlit chatbot template, new serving-utils helper module (`@@ -0,0 +1,24 @@`): the same helper again, except that `query_endpoint` gains the docstring requested in review:

```python
def query_endpoint(endpoint_name, messages, max_tokens):
    """
    Query a chat-completions or agent serving endpoint.
    If querying an agent serving endpoint that returns multiple messages, this method
    returns the last message.
    """
    return _query_endpoint(endpoint_name, messages, max_tokens)[-1]
```
Streamlit chatbot template `requirements.txt`:

```diff
@@ -1 +1,2 @@
-openai
+mlflow>=2.21.2
+streamlit==1.44.1
```
Contributor: Thanks for fixing this! The template is actually broken in production right now, because Dash 3+ no longer supports the `run_server` method. 😅
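For reference, the Dash 3 breakage mentioned here comes from the removal of `app.run_server`; `app.run` is the replacement entry point. A minimal sketch:

```python
import dash

app = dash.Dash(__name__)

if __name__ == "__main__":
    # Dash 3 removed app.run_server; app.run is the supported entry point.
    app.run(debug=True)
```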