Skip to content

Commit

Permalink
Improve schema of CosmosDB chat history to handle long conversations (#2312)
Browse files Browse the repository at this point in the history

* Configure Azure Developer Pipeline

* CosmosDB v2 changes

* CosmosDB progress

* Fix CosmosDB API

* Revert unneeded changes

* Fix tests

* Rename message to message_pair

* Address Matt's feedback

* Add version env var

* Reformat with latest black

* Minor updates and test fix

* Changes based on Mark's call

* Fix the frontend for the HistoryList API
  • Loading branch information
pamelafox authored Jan 29, 2025
1 parent a891ab3 commit 7a2044a
Show file tree
Hide file tree
Showing 14 changed files with 268 additions and 154 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
rev: v0.9.3
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 24.10.0
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-prettier
Expand Down
115 changes: 81 additions & 34 deletions app/backend/chat_history/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from azure.cosmos.aio import ContainerProxy, CosmosClient
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential
from quart import Blueprint, current_app, jsonify, request
from quart import Blueprint, current_app, jsonify, make_response, request

from config import (
CONFIG_CHAT_HISTORY_COSMOS_ENABLED,
CONFIG_COSMOS_HISTORY_CLIENT,
CONFIG_COSMOS_HISTORY_CONTAINER,
CONFIG_COSMOS_HISTORY_VERSION,
CONFIG_CREDENTIAL,
)
from decorators import authenticated
Expand All @@ -34,23 +35,50 @@ async def post_chat_history(auth_claims: Dict[str, Any]):

try:
request_json = await request.get_json()
id = request_json.get("id")
answers = request_json.get("answers")
title = answers[0][0][:50] + "..." if len(answers[0][0]) > 50 else answers[0][0]
session_id = request_json.get("id")
message_pairs = request_json.get("answers")
first_question = message_pairs[0][0]
# Keep the stored session title short: cap it at the first 50 characters of
# the first question and append an ellipsis only when something was cut off.
# (Without the slice, long questions were stored in full plus "...".)
title = (first_question[:50] + "...") if len(first_question) > 50 else first_question
timestamp = int(time.time() * 1000)

await container.upsert_item(
{"id": id, "entra_oid": entra_oid, "title": title, "answers": answers, "timestamp": timestamp}
)
# Insert the session item:
session_item = {
"id": session_id,
"version": current_app.config[CONFIG_COSMOS_HISTORY_VERSION],
"session_id": session_id,
"entra_oid": entra_oid,
"type": "session",
"title": title,
"timestamp": timestamp,
}

message_pair_items = []
# Now insert a message item for each question/response pair:
for ind, message_pair in enumerate(message_pairs):
message_pair_items.append(
{
"id": f"{session_id}-{ind}",
"version": current_app.config[CONFIG_COSMOS_HISTORY_VERSION],
"session_id": session_id,
"entra_oid": entra_oid,
"type": "message_pair",
"question": message_pair[0],
"response": message_pair[1],
}
)

batch_operations = [("upsert", (session_item,))] + [
("upsert", (message_pair_item,)) for message_pair_item in message_pair_items
]
await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id])
return jsonify({}), 201
except Exception as error:
return error_response(error, "/chat_history")


@chat_history_cosmosdb_bp.post("/chat_history/items")
@chat_history_cosmosdb_bp.get("/chat_history/sessions")
@authenticated
async def get_chat_history(auth_claims: Dict[str, Any]):
async def get_chat_history_sessions(auth_claims: Dict[str, Any]):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -63,27 +91,26 @@ async def get_chat_history(auth_claims: Dict[str, Any]):
return jsonify({"error": "User OID not found"}), 401

try:
request_json = await request.get_json()
count = request_json.get("count", 20)
continuation_token = request_json.get("continuation_token")
count = int(request.args.get("count", 10))
continuation_token = request.args.get("continuation_token")

res = container.query_items(
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC",
parameters=[dict(name="@entra_oid", value=entra_oid)],
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid AND c.type = @type ORDER BY c.timestamp DESC",
parameters=[dict(name="@entra_oid", value=entra_oid), dict(name="@type", value="session")],
partition_key=[entra_oid],
max_item_count=count,
)

# set the continuation token for the next page
pager = res.by_page(continuation_token)

# Get the first page, and the continuation token
sessions = []
try:
page = await pager.__anext__()
continuation_token = pager.continuation_token # type: ignore

items = []
async for item in page:
items.append(
sessions.append(
{
"id": item.get("id"),
"entra_oid": item.get("entra_oid"),
Expand All @@ -94,18 +121,17 @@ async def get_chat_history(auth_claims: Dict[str, Any]):

# If there are no more pages, StopAsyncIteration is raised
except StopAsyncIteration:
items = []
continuation_token = None

return jsonify({"items": items, "continuation_token": continuation_token}), 200
return jsonify({"sessions": sessions, "continuation_token": continuation_token}), 200

except Exception as error:
return error_response(error, "/chat_history/items")
return error_response(error, "/chat_history/sessions")


@chat_history_cosmosdb_bp.get("/chat_history/items/<item_id>")
@chat_history_cosmosdb_bp.get("/chat_history/sessions/<session_id>")
@authenticated
async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
async def get_chat_history_session(auth_claims: Dict[str, Any], session_id: str):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -118,26 +144,34 @@ async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
return jsonify({"error": "User OID not found"}), 401

try:
res = await container.read_item(item=item_id, partition_key=entra_oid)
res = container.query_items(
query="SELECT * FROM c WHERE c.session_id = @session_id AND c.type = @type",
parameters=[dict(name="@session_id", value=session_id), dict(name="@type", value="message_pair")],
partition_key=[entra_oid, session_id],
)

message_pairs = []
async for page in res.by_page():
async for item in page:
message_pairs.append([item["question"], item["response"]])

return (
jsonify(
{
"id": res.get("id"),
"entra_oid": res.get("entra_oid"),
"title": res.get("title", "untitled"),
"timestamp": res.get("timestamp"),
"answers": res.get("answers", []),
"id": session_id,
"entra_oid": entra_oid,
"answers": message_pairs,
}
),
200,
)
except Exception as error:
return error_response(error, f"/chat_history/items/{item_id}")
return error_response(error, f"/chat_history/sessions/{session_id}")


@chat_history_cosmosdb_bp.delete("/chat_history/items/<item_id>")
@chat_history_cosmosdb_bp.delete("/chat_history/sessions/<session_id>")
@authenticated
async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
async def delete_chat_history_session(auth_claims: Dict[str, Any], session_id: str):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -150,10 +184,22 @@ async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str)
return jsonify({"error": "User OID not found"}), 401

try:
await container.delete_item(item=item_id, partition_key=entra_oid)
return jsonify({}), 204
res = container.query_items(
query="SELECT c.id FROM c WHERE c.session_id = @session_id",
parameters=[dict(name="@session_id", value=session_id)],
partition_key=[entra_oid, session_id],
)

ids_to_delete = []
async for page in res.by_page():
async for item in page:
ids_to_delete.append(item["id"])

batch_operations = [("delete", (id,)) for id in ids_to_delete]
await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id])
return await make_response("", 204)
except Exception as error:
return error_response(error, f"/chat_history/items/{item_id}")
return error_response(error, f"/chat_history/sessions/{session_id}")


@chat_history_cosmosdb_bp.before_app_serving
Expand Down Expand Up @@ -183,6 +229,7 @@ async def setup_clients():

current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] = cosmos_client
current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] = cosmos_container
current_app.config[CONFIG_COSMOS_HISTORY_VERSION] = os.environ["AZURE_CHAT_HISTORY_VERSION"]


@chat_history_cosmosdb_bp.after_app_serving
Expand Down
1 change: 1 addition & 0 deletions app/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
CONFIG_CHAT_HISTORY_COSMOS_ENABLED = "chat_history_cosmos_enabled"
CONFIG_COSMOS_HISTORY_CLIENT = "cosmos_history_client"
CONFIG_COSMOS_HISTORY_CONTAINER = "cosmos_history_container"
CONFIG_COSMOS_HISTORY_VERSION = "cosmos_history_version"
2 changes: 1 addition & 1 deletion app/backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ azure-core==1.30.2
# msrest
azure-core-tracing-opentelemetry==1.0.0b11
# via azure-monitor-opentelemetry
azure-cosmos==4.7.0
azure-cosmos==4.9.0
# via -r requirements.in
azure-identity==1.17.1
# via
Expand Down
25 changes: 13 additions & 12 deletions app/frontend/src/api/api.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const BACKEND_URI = "";

import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse, HistoryListApiResponse, HistroyApiResponse } from "./models";
import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse, HistoryListApiResponse, HistoryApiResponse } from "./models";
import { useLogin, getToken, isUsingAppServicesLogin } from "../authConfig";

export async function getHeaders(idToken: string | undefined): Promise<Record<string, string>> {
Expand Down Expand Up @@ -145,10 +145,14 @@ export async function postChatHistoryApi(item: any, idToken: string): Promise<an

export async function getChatHistoryListApi(count: number, continuationToken: string | undefined, idToken: string): Promise<HistoryListApiResponse> {
const headers = await getHeaders(idToken);
const response = await fetch("/chat_history/items", {
method: "POST",
headers: { ...headers, "Content-Type": "application/json" },
body: JSON.stringify({ count: count, continuation_token: continuationToken })
let url = `${BACKEND_URI}/chat_history/sessions?count=${count}`;
if (continuationToken) {
// The backend reads the token from the `continuation_token` query parameter,
// so the name must match exactly or paging silently restarts from page one.
// The token is an opaque string, so URL-encode it before appending.
url += `&continuation_token=${encodeURIComponent(continuationToken)}`;
}

const response = await fetch(url.toString(), {
method: "GET",
headers: { ...headers, "Content-Type": "application/json" }
});

if (!response.ok) {
Expand All @@ -159,9 +163,9 @@ export async function getChatHistoryListApi(count: number, continuationToken: st
return dataResponse;
}

export async function getChatHistoryApi(id: string, idToken: string): Promise<HistroyApiResponse> {
export async function getChatHistoryApi(id: string, idToken: string): Promise<HistoryApiResponse> {
const headers = await getHeaders(idToken);
const response = await fetch(`/chat_history/items/${id}`, {
const response = await fetch(`/chat_history/sessions/${id}`, {
method: "GET",
headers: { ...headers, "Content-Type": "application/json" }
});
Expand All @@ -170,21 +174,18 @@ export async function getChatHistoryApi(id: string, idToken: string): Promise<Hi
throw new Error(`Getting chat history failed: ${response.statusText}`);
}

const dataResponse: HistroyApiResponse = await response.json();
const dataResponse: HistoryApiResponse = await response.json();
return dataResponse;
}

export async function deleteChatHistoryApi(id: string, idToken: string): Promise<any> {
const headers = await getHeaders(idToken);
const response = await fetch(`/chat_history/items/${id}`, {
const response = await fetch(`/chat_history/sessions/${id}`, {
method: "DELETE",
headers: { ...headers, "Content-Type": "application/json" }
});

if (!response.ok) {
throw new Error(`Deleting chat history failed: ${response.statusText}`);
}

const dataResponse: any = await response.json();
return dataResponse;
}
6 changes: 2 additions & 4 deletions app/frontend/src/api/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export interface SpeechConfig {
}

export type HistoryListApiResponse = {
items: {
sessions: {
id: string;
entra_oid: string;
title: string;
Expand All @@ -116,10 +116,8 @@ export type HistoryListApiResponse = {
continuation_token?: string;
};

export type HistroyApiResponse = {
export type HistoryApiResponse = {
id: string;
entra_oid: string;
title: string;
answers: any;
timestamp: number;
};
8 changes: 4 additions & 4 deletions app/frontend/src/components/HistoryProviders/CosmosDB.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ export class CosmosDBProvider implements IHistoryProvider {
if (!this.continuationToken) {
this.isItemEnd = true;
}
return response.items.map(item => ({
id: item.id,
title: item.title,
timestamp: item.timestamp
return response.sessions.map(session => ({
id: session.id,
title: session.title,
timestamp: session.timestamp
}));
} catch (e) {
console.error(e);
Expand Down
Loading

0 comments on commit 7a2044a

Please sign in to comment.