Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve schema of CosmosDB chat history to handle long conversations #2312

Merged
merged 14 commits into from
Jan 29, 2025
122 changes: 88 additions & 34 deletions app/backend/chat_history/cosmosdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from azure.cosmos.aio import ContainerProxy, CosmosClient
from azure.identity.aio import AzureDeveloperCliCredential, ManagedIdentityCredential
from quart import Blueprint, current_app, jsonify, request
from quart import Blueprint, current_app, jsonify, make_response, request

from config import (
CONFIG_CHAT_HISTORY_COSMOS_ENABLED,
Expand Down Expand Up @@ -34,23 +34,50 @@ async def post_chat_history(auth_claims: Dict[str, Any]):

try:
request_json = await request.get_json()
id = request_json.get("id")
answers = request_json.get("answers")
title = answers[0][0][:50] + "..." if len(answers[0][0]) > 50 else answers[0][0]
session_id = request_json.get("id")
message_pairs = request_json.get("answers")
first_question = message_pairs[0][0]
title = first_question + "..." if len(first_question) > 50 else first_question
timestamp = int(time.time() * 1000)

await container.upsert_item(
{"id": id, "entra_oid": entra_oid, "title": title, "answers": answers, "timestamp": timestamp}
)
# Build the session item (it is upserted together with the message pairs in the batch below):
session_item = {
pamelafox marked this conversation as resolved.
Show resolved Hide resolved
"id": session_id,
"session_id": session_id,
"entra_oid": entra_oid,
"type": "session",
"title": title,
"timestamp": timestamp,
}

message_pair_items = []
# Build one message_pair item for each question/response pair (nothing is written until the batch executes):
for ind, message_pair in enumerate(message_pairs):
message_pair_items.append(
{
"id": f"{session_id}-{ind}",
pamelafox marked this conversation as resolved.
Show resolved Hide resolved
"session_id": session_id,
"entra_oid": entra_oid,
"type": "message_pair",
"question": message_pair[0],
"response": message_pair[1],
"timestamp": None,
}
)

batch_operations = [("upsert", (session_item,))] + [
("upsert", (message_pair_item,)) for message_pair_item in message_pair_items
]

await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id])
return jsonify({}), 201
except Exception as error:
return error_response(error, "/chat_history")


@chat_history_cosmosdb_bp.post("/chat_history/items")
@chat_history_cosmosdb_bp.get("/chat_history/sessions")
@authenticated
async def get_chat_history(auth_claims: Dict[str, Any]):
async def get_chat_history_sessions(auth_claims: Dict[str, Any]):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -63,27 +90,26 @@ async def get_chat_history(auth_claims: Dict[str, Any]):
return jsonify({"error": "User OID not found"}), 401

try:
request_json = await request.get_json()
count = request_json.get("count", 20)
continuation_token = request_json.get("continuation_token")
count = int(request.args.get("count", 10))
continuation_token = request.args.get("continuation_token")

res = container.query_items(
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid ORDER BY c.timestamp DESC",
parameters=[dict(name="@entra_oid", value=entra_oid)],
query="SELECT c.id, c.entra_oid, c.title, c.timestamp FROM c WHERE c.entra_oid = @entra_oid AND c.type = @type ORDER BY c.timestamp DESC",
parameters=[dict(name="@entra_oid", value=entra_oid), dict(name="@type", value="session")],
partition_key=[entra_oid],
max_item_count=count,
)

# set the continuation token for the next page
pager = res.by_page(continuation_token)

# Get the first page, and the continuation token
sessions = []
try:
page = await pager.__anext__()
continuation_token = pager.continuation_token # type: ignore

items = []
async for item in page:
items.append(
sessions.append(
{
"id": item.get("id"),
"entra_oid": item.get("entra_oid"),
Expand All @@ -94,18 +120,17 @@ async def get_chat_history(auth_claims: Dict[str, Any]):

# If there are no more pages, StopAsyncIteration is raised
except StopAsyncIteration:
items = []
continuation_token = None

return jsonify({"items": items, "continuation_token": continuation_token}), 200
return jsonify({"sessions": sessions, "continuation_token": continuation_token}), 200

except Exception as error:
return error_response(error, "/chat_history/items")
return error_response(error, "/chat_history/sessions")


@chat_history_cosmosdb_bp.get("/chat_history/items/<item_id>")
@chat_history_cosmosdb_bp.get("/chat_history/sessions/<session_id>")
@authenticated
async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
async def get_chat_history_session(auth_claims: Dict[str, Any], session_id: str):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -118,26 +143,43 @@ async def get_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
return jsonify({"error": "User OID not found"}), 401

try:
res = await container.read_item(item=item_id, partition_key=entra_oid)
res = container.query_items(
query="SELECT * FROM c WHERE c.session_id = @session_id",
parameters=[dict(name="@entra_oid", value=entra_oid), dict(name="@session_id", value=session_id)],
partition_key=[entra_oid, session_id],
)

message_pairs = []
session = None
async for page in res.by_page():
async for item in page:
if item.get("type") == "session":
session = item
elif item.get("type") == "message_pair":
message_pairs.append([item["question"], item["response"]])

if session is None:
return jsonify({"error": "Session not found"}), 404

return (
jsonify(
{
"id": res.get("id"),
"entra_oid": res.get("entra_oid"),
"title": res.get("title", "untitled"),
"timestamp": res.get("timestamp"),
"answers": res.get("answers", []),
"id": session.get("id"),
"entra_oid": entra_oid,
"title": session.get("title"),
"timestamp": session.get("timestamp"),
"answers": message_pairs,
}
),
200,
)
except Exception as error:
return error_response(error, f"/chat_history/items/{item_id}")
return error_response(error, f"/chat_history/sessions/{session_id}")


@chat_history_cosmosdb_bp.delete("/chat_history/items/<item_id>")
@chat_history_cosmosdb_bp.delete("/chat_history/sessions/<session_id>")
@authenticated
async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str):
async def delete_chat_history_session(auth_claims: Dict[str, Any], session_id: str):
if not current_app.config[CONFIG_CHAT_HISTORY_COSMOS_ENABLED]:
return jsonify({"error": "Chat history not enabled"}), 400

Expand All @@ -150,10 +192,22 @@ async def delete_chat_history_session(auth_claims: Dict[str, Any], item_id: str)
return jsonify({"error": "User OID not found"}), 401

try:
await container.delete_item(item=item_id, partition_key=entra_oid)
return jsonify({}), 204
res = container.query_items(
query="SELECT * FROM c WHERE c.session_id = @session_id",
pamelafox marked this conversation as resolved.
Show resolved Hide resolved
parameters=[dict(name="@entra_oid", value=entra_oid), dict(name="@session_id", value=session_id)],
partition_key=[entra_oid, session_id],
)

ids_to_delete = []
async for page in res.by_page():
async for item in page:
ids_to_delete.append(item["id"])

batch_operations = [("delete", (id,)) for id in ids_to_delete]
await container.execute_item_batch(batch_operations=batch_operations, partition_key=[entra_oid, session_id])
return await make_response("", 204)
except Exception as error:
return error_response(error, f"/chat_history/items/{item_id}")
return error_response(error, f"/chat_history/sessions/{session_id}")


@chat_history_cosmosdb_bp.before_app_serving
Expand Down
19 changes: 10 additions & 9 deletions app/frontend/src/api/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,14 @@ export async function postChatHistoryApi(item: any, idToken: string): Promise<an

export async function getChatHistoryListApi(count: number, continuationToken: string | undefined, idToken: string): Promise<HistoryListApiResponse> {
const headers = await getHeaders(idToken);
const response = await fetch("/chat_history/items", {
method: "POST",
headers: { ...headers, "Content-Type": "application/json" },
body: JSON.stringify({ count: count, continuation_token: continuationToken })
let url = `${BACKEND_URI}/chat_history/sessions?count=${count}`;
if (continuationToken) {
url += `&continuationToken=${continuationToken}`;
}

const response = await fetch(url.toString(), {
method: "GET",
headers: { ...headers, "Content-Type": "application/json" }
});

if (!response.ok) {
Expand All @@ -161,7 +165,7 @@ export async function getChatHistoryListApi(count: number, continuationToken: st

export async function getChatHistoryApi(id: string, idToken: string): Promise<HistroyApiResponse> {
const headers = await getHeaders(idToken);
const response = await fetch(`/chat_history/items/${id}`, {
const response = await fetch(`/chat_history/sessions/${id}`, {
method: "GET",
headers: { ...headers, "Content-Type": "application/json" }
});
Expand All @@ -176,15 +180,12 @@ export async function getChatHistoryApi(id: string, idToken: string): Promise<Hi

export async function deleteChatHistoryApi(id: string, idToken: string): Promise<any> {
const headers = await getHeaders(idToken);
const response = await fetch(`/chat_history/items/${id}`, {
const response = await fetch(`/chat_history/sessions/${id}`, {
method: "DELETE",
headers: { ...headers, "Content-Type": "application/json" }
});

if (!response.ok) {
throw new Error(`Deleting chat history failed: ${response.statusText}`);
}

const dataResponse: any = await response.json();
return dataResponse;
}
2 changes: 1 addition & 1 deletion app/frontend/src/api/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export interface SpeechConfig {
}

export type HistoryListApiResponse = {
items: {
sessions: {
id: string;
entra_oid: string;
title: string;
Expand Down
8 changes: 4 additions & 4 deletions app/frontend/src/components/HistoryProviders/CosmosDB.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ export class CosmosDBProvider implements IHistoryProvider {
if (!this.continuationToken) {
this.isItemEnd = true;
}
return response.items.map(item => ({
id: item.id,
title: item.title,
timestamp: item.timestamp
return response.sessions.map(session => ({
id: session.id,
title: session.title,
timestamp: session.timestamp
}));
} catch (e) {
console.error(e);
Expand Down
1 change: 1 addition & 0 deletions app/frontend/src/components/LoginButton/LoginButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export const LoginButton = () => {
})
.catch(error => console.log(error))
.then(async () => {
debugger;
setLoggedIn(await checkLoggedIn(instance));
setUsername((await getUsername(instance)) ?? "");
});
Expand Down
19 changes: 12 additions & 7 deletions infra/main.bicep
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ param cosmosDbLocation string = ''
param cosmosDbAccountName string = ''
param cosmosDbThroughput int = 400
param chatHistoryDatabaseName string = 'chat-database'
param chatHistoryContainerName string = 'chat-history'
param chatHistoryContainerName string = 'chat-history-container'

// https://learn.microsoft.com/azure/ai-services/openai/concepts/models?tabs=python-secure%2Cstandard%2Cstandard-chat-completions#standard-deployment-model-availability
@description('Location for the OpenAI resource group')
Expand Down Expand Up @@ -799,26 +799,31 @@ module cosmosDb 'br/public:avm/res/document-db/database-account:0.6.1' = if (use
containers: [
{
name: chatHistoryContainerName
kind: 'MultiHash'
pamelafox marked this conversation as resolved.
Show resolved Hide resolved
paths: [
'/entra_oid'
'/session_id'
]
indexingPolicy: {
indexingMode: 'consistent'
automatic: true
includedPaths: [
{
path: '/*'
path: '/entra_oid/?'
}
{
path: '/session_id/?'
}
]
excludedPaths: [
{
path: '/title/?'
path: '/timestamp/?'
}
{
path: '/answers/*'
path: '/type/?'
}
]
excludedPaths: [
{
path: '/"_etag"/?'
path: '/*'
}
]
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"answers": [
[
"This is a test message"
"This is a test message",
"This is a test answer"
]
],
"entra_oid": "OID_X",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"continuation_token": "next",
"items": [
"sessions": [
{
"entra_oid": "OID_X",
"id": "123",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"continuation_token": null,
"items": []
"sessions": []
}
Loading
Loading