Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve schema of CosmosDB chat history to handle long conversations #2312

Merged
merged 14 commits into from
Jan 29, 2025
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.2
rev: v0.9.3
hooks:
- id: ruff
- repo: https://github.com/psf/black
rev: 24.10.0
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-prettier
Expand Down
115 changes: 81 additions & 34 deletions app/backend/chat_history/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from azure.cosmos.aio import ContainerProxy, CosmosClient
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential
from quart import Blueprint, current_app, jsonify, request
from quart import Blueprint, current_app, jsonify, make_response, request

from config import (
CONFIG_CHAT_HISTORY_COSMOS_ENABLED,
CONFIG_COSMOS_HISTORY_CLIENT,
CONFIG_COSMOS_HISTORY_CONTAINER,
CONFIG_COSMOS_HISTORY_VERSION,
CONFIG_CREDENTIAL,
)
from decorators import authenticated
Expand All @@ -34,23 +35,50 @@ async def post_chat_history(auth_claims: Dict[str, Any]):

try:
request_json = await request.get_json()
id = request_json.get("id")
answers = request_json.get("answers")
title = answers[0][0][:50] + "..." if len(answers[0][0]) > 50 else answers[0][0]
session_id = request_json.get("id")
message_pairs = request_json.get("answers")
first_question = message_pairs[0][0]
title = first_question + "..." if len(first_question) > 50 else first_question
timestamp = int(time.time() * 1000)

await container.upsert_item(
{"id": id, "entra_oid": entra_oid, "title": title, "answers": answers, "timestamp": timestamp}
)
# Insert the session item:
session_item = {
"id": session_id,
"version": current_app.config[CONFIG_COSMOS_HISTORY_VERSION],
"session_id": session_id,
"entra_oid": entra_oid,
"type": "session",
"title": title,
"timestamp": timestamp,
}

message_pair_items = []
# Now insert a message item for each question/response pair:
for ind, message_pair in enumerate(message_pairs):
message_pair_items.append(
{
"id": f"{session_id}-{ind}",
"version": current_app.config[CONFIG_COSMOS_HISTORY_VERSION],
"session_id": session_id,
"entra_oid": entra_oid,
"type": "message_pair",
"question": message_pair[0],
"response": message_pair[1],
}
)

batch_operations = [("upsert", (session_item,))] + [
("upsert", (message_pair_item,)) for message_pair_item in message_pair_items
]
await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id])
return jsonify({}), 201
except Exception as error:
return error_response(error, "/chat_history")


@chat_history_cosmosdb_bp.post("/chat_history/items")
@chat_history_cosmosdb_bp.get("/chat_history/sessions")
@authenticated
async def get_chat_history(auth_claims: Dict[str, Any]):
async def get_chat_history_sessions(auth_claims: Dict[str, Any]):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -63,27 +91,26 @@ async def get_chat_history(auth_claims: Dict[str, Any]):
return jsonify({"error": "User OID not found"}), 401

try:
request_json = await request.get_json()
count = request_json.get("count", 20)
continuation_token = request_json.get("continuation_token")
count = int(request.args.get("count", 10))
continuation_token = request.args.get("continuation_token")

res = container.query_items(
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC",
parameters=[dict(name="@entra_oid", value=entra_oid)],
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid AND c.type = @type ORDER BY c.timestamp DESC",
parameters=[dict(name="@entra_oid", value=entra_oid), dict(name="@type", value="session")],
partition_key=[entra_oid],
max_item_count=count,
)

# set the continuation token for the next page
pager = res.by_page(continuation_token)

# Get the first page, and the continuation token
sessions = []
try:
page = await pager.__anext__()
continuation_token = pager.continuation_token # type: ignore

items = []
async for item in page:
items.append(
sessions.append(
{
"id": item.get("id"),
"entra_oid": item.get("entra_oid"),
Expand All @@ -94,18 +121,17 @@ async def get_chat_history(auth_claims: Dict[str, Any]):

# If there are no more pages, StopAsyncIteration is raised
except StopAsyncIteration:
items = []
continuation_token = None

return jsonify({"items": items, "continuation_token": continuation_token}), 200
return jsonify({"sessions": sessions, "continuation_token": continuation_token}), 200

except Exception as error:
return error_response(error, "/chat_history/items")
return error_response(error, "/chat_history/sessions")


@chat_history_cosmosdb_bp.get("/chat_history/items/<item_id>")
@chat_history_cosmosdb_bp.get("/chat_history/sessions/<session_id>")
@authenticated
async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
async def get_chat_history_session(auth_claims: Dict[str, Any], session_id: str):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -118,26 +144,34 @@ async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
return jsonify({"error": "User OID not found"}), 401

try:
res = await container.read_item(item=item_id, partition_key=entra_oid)
res = container.query_items(
query="SELECT * FROM c WHERE c.session_id = @session_id AND c.type = @type",
parameters=[dict(name="@session_id", value=session_id), dict(name="@type", value="message_pair")],
partition_key=[entra_oid, session_id],
)

message_pairs = []
async for page in res.by_page():
async for item in page:
message_pairs.append([item["question"], item["response"]])

return (
jsonify(
{
"id": res.get("id"),
"entra_oid": res.get("entra_oid"),
"title": res.get("title", "untitled"),
"timestamp": res.get("timestamp"),
"answers": res.get("answers", []),
"id": session_id,
"entra_oid": entra_oid,
"answers": message_pairs,
}
),
200,
)
except Exception as error:
return error_response(error, f"/chat_history/items/{item_id}")
return error_response(error, f"/chat_history/sessions/{session_id}")


@chat_history_cosmosdb_bp.delete("/chat_history/items/<item_id>")
@chat_history_cosmosdb_bp.delete("/chat_history/sessions/<session_id>")
@authenticated
async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
async def delete_chat_history_session(auth_claims: Dict[str, Any], session_id: str):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -150,10 +184,22 @@ async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str)
return jsonify({"error": "User OID not found"}), 401

try:
await container.delete_item(item=item_id, partition_key=entra_oid)
return jsonify({}), 204
res = container.query_items(
query="SELECT c.id FROM c WHERE c.session_id = @session_id",
parameters=[dict(name="@session_id", value=session_id)],
partition_key=[entra_oid, session_id],
)

ids_to_delete = []
async for page in res.by_page():
async for item in page:
ids_to_delete.append(item["id"])

batch_operations = [("delete", (id,)) for id in ids_to_delete]
await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id])
return await make_response("", 204)
except Exception as error:
return error_response(error, f"/chat_history/items/{item_id}")
return error_response(error, f"/chat_history/sessions/{session_id}")


@chat_history_cosmosdb_bp.before_app_serving
Expand Down Expand Up @@ -183,6 +229,7 @@ async def setup_clients():

current_app.config[CONFIG_COSMOS_HISTORY_CLIENT] = cosmos_client
current_app.config[CONFIG_COSMOS_HISTORY_CONTAINER] = cosmos_container
current_app.config[CONFIG_COSMOS_HISTORY_VERSION] = os.environ["AZURE_CHAT_HISTORY_VERSION"]


@chat_history_cosmosdb_bp.after_app_serving
Expand Down
1 change: 1 addition & 0 deletions app/backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@
CONFIG_CHAT_HISTORY_COSMOS_ENABLED = "chat_history_cosmos_enabled"
CONFIG_COSMOS_HISTORY_CLIENT = "cosmos_history_client"
CONFIG_COSMOS_HISTORY_CONTAINER = "cosmos_history_container"
CONFIG_COSMOS_HISTORY_VERSION = "cosmos_history_version"
2 changes: 1 addition & 1 deletion app/backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ azure-core==1.30.2
# msrest
azure-core-tracing-opentelemetry==1.0.0b11
# via azure-monitor-opentelemetry
azure-cosmos==4.7.0
azure-cosmos==4.9.0
# via -r requirements.in
azure-identity==1.17.1
# via
Expand Down
25 changes: 13 additions & 12 deletions app/frontend/src/api/api.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const BACKEND_URI = "";

import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse, HistoryListApiResponse, HistroyApiResponse } from "./models";
import { ChatAppResponse, ChatAppResponseOrError, ChatAppRequest, Config, SimpleAPIResponse, HistoryListApiResponse, HistoryApiResponse } from "./models";
import { useLogin, getToken, isUsingAppServicesLogin } from "../authConfig";

export async function getHeaders(idToken: string | undefined): Promise<Record<string, string>> {
Expand Down Expand Up @@ -145,10 +145,14 @@ export async function postChatHistoryApi(item: any, idToken: string): Promise<an

export async function getChatHistoryListApi(count: number, continuationToken: string | undefined, idToken: string): Promise<HistoryListApiResponse> {
const headers = await getHeaders(idToken);
const response = await fetch("/chat_history/items", {
method: "POST",
headers: { ...headers, "Content-Type": "application/json" },
body: JSON.stringify({ count: count, continuation_token: continuationToken })
let url = `${BACKEND_URI}/chat_history/sessions?count=${count}`;
if (continuationToken) {
url += `&continuationToken=${continuationToken}`;
}

const response = await fetch(url.toString(), {
method: "GET",
headers: { ...headers, "Content-Type": "application/json" }
});

if (!response.ok) {
Expand All @@ -159,9 +163,9 @@ export async function getChatHistoryListApi(count: number, continuationToken: st
return dataResponse;
}

export async function getChatHistoryApi(id: string, idToken: string): Promise<HistroyApiResponse> {
export async function getChatHistoryApi(id: string, idToken: string): Promise<HistoryApiResponse> {
const headers = await getHeaders(idToken);
const response = await fetch(`/chat_history/items/${id}`, {
const response = await fetch(`/chat_history/sessions/${id}`, {
method: "GET",
headers: { ...headers, "Content-Type": "application/json" }
});
Expand All @@ -170,21 +174,18 @@ export async function getChatHistoryApi(id: string, idToken: string): Promise<Hi
throw new Error(`Getting chat history failed: ${response.statusText}`);
}

const dataResponse: HistroyApiResponse = await response.json();
const dataResponse: HistoryApiResponse = await response.json();
return dataResponse;
}

export async function deleteChatHistoryApi(id: string, idToken: string): Promise<any> {
const headers = await getHeaders(idToken);
const response = await fetch(`/chat_history/items/${id}`, {
const response = await fetch(`/chat_history/sessions/${id}`, {
method: "DELETE",
headers: { ...headers, "Content-Type": "application/json" }
});

if (!response.ok) {
throw new Error(`Deleting chat history failed: ${response.statusText}`);
}

const dataResponse: any = await response.json();
return dataResponse;
}
6 changes: 2 additions & 4 deletions app/frontend/src/api/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export interface SpeechConfig {
}

export type HistoryListApiResponse = {
items: {
sessions: {
id: string;
entra_oid: string;
title: string;
Expand All @@ -116,10 +116,8 @@ export type HistoryListApiResponse = {
continuation_token?: string;
};

export type HistroyApiResponse = {
export type HistoryApiResponse = {
id: string;
entra_oid: string;
title: string;
answers: any;
timestamp: number;
};
8 changes: 4 additions & 4 deletions app/frontend/src/components/HistoryProviders/CosmosDB.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ export class CosmosDBProvider implements IHistoryProvider {
if (!this.continuationToken) {
this.isItemEnd = true;
}
return response.items.map(item => ({
id: item.id,
title: item.title,
timestamp: item.timestamp
return response.sessions.map(session => ({
id: session.id,
title: session.title,
timestamp: session.timestamp
}));
} catch (e) {
console.error(e);
Expand Down
Loading
Loading